diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000000..1e35e0c496 --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +max-line-length = 100 +extend-ignore = E203,E501,F401,E402,E714 +per-file-ignores = __init__.py:F401 \ No newline at end of file diff --git a/.gitignore b/.gitignore index 5955b349f1..fd081062eb 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,8 @@ build slurm* logs .vscode +local/ +.gitmodules +wandb/ +onelogger.log +onelogger.err \ No newline at end of file diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index b8b5423c13..3b2e4e1502 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,621 +1,141 @@ -image: nvcr.io/nvidia/pytorch:23.04-py3 +workflow: + rules: + - if: $CI_PROJECT_NAMESPACE != "ADLR" + when: never + - if: $CI_COMMIT_BRANCH =~ /ci-/ && $CI_PIPELINE_SOURCE != "schedule" + when: never + - if: $CI_PIPELINE_SOURCE == "schedule" + auto_cancel: + on_new_commit: none + - if: $CI_PIPELINE_SOURCE == "web" + - if: $CI_COMMIT_REF_PROTECTED == "true" + variables: + FUNCTIONAL_TEST: 'no' + - if: $CI_MERGE_REQUEST_LABELS =~ /Run tests/ && $CI_MERGE_REQUEST_TARGET_BRANCH_SHA != "" + variables: + UNIT_TEST_REPEAT: 1 + UNIT_TEST_TIMEOUT: 15 + FUNCTIONAL_TEST: 'yes' + FUNCTIONAL_TEST_SCOPE: mr + FUNCTIONAL_TEST_REPEAT: 5 + FUNCTIONAL_TEST_TIME_LIMIT: 2700 + FUNCTIONAL_TEST_CLUSTER_A100: '' + FUNCTIONAL_TEST_CLUSTER_H100: '' + PUBLISH: 'no' + - if: $CI_MERGE_REQUEST_LABELS =~ /Run nightly/ && $CI_MERGE_REQUEST_TARGET_BRANCH_SHA != "" + variables: + UNIT_TEST_REPEAT: 1 + UNIT_TEST_TIMEOUT: 15 + FUNCTIONAL_TEST: 'yes' + FUNCTIONAL_TEST_SCOPE: nightly + FUNCTIONAL_TEST_REPEAT: 5 + FUNCTIONAL_TEST_TIME_LIMIT: 2700 + FUNCTIONAL_TEST_CLUSTER_A100: '' + FUNCTIONAL_TEST_CLUSTER_H100: '' + PUBLISH: 'no' + - if: $CI_MERGE_REQUEST_LABELS =~ /Run weekly/ && $CI_MERGE_REQUEST_TARGET_BRANCH_SHA != "" + variables: + UNIT_TEST_REPEAT: 1 + UNIT_TEST_TIMEOUT: 15 + FUNCTIONAL_TEST: 'yes' + FUNCTIONAL_TEST_SCOPE: weekly + FUNCTIONAL_TEST_REPEAT: 1 + FUNCTIONAL_TEST_TIME_LIMIT: 9000 + FUNCTIONAL_TEST_CLUSTER_A100: '' + FUNCTIONAL_TEST_CLUSTER_H100: '' + PUBLISH: 'no' + - if: $CI_PIPELINE_SOURCE == "merge_request_event" && $CI_MERGE_REQUEST_TARGET_BRANCH_SHA != "" + variables: + FUNCTIONAL_TEST: 'no' + PUBLISH: 'no' + - when: never + auto_cancel: + on_new_commit: interruptible + # on_job_failure: all stages: - test - - cleanup - -variables: &VARS - SELENE_ADLR_CI_PATH: "/lustre/fsw/adlr/adlr-nlp/adlr_ci/megatron" - DATA_DIR: "/lustre/fsw/adlr/adlr-nlp/adlr_ci/megatron/data" - PYTORCH_IMAGE: /lustre/fsw/adlr/adlr-nlp/adlr_ci/megatron/nvcr_pytorch_23.04.sqsh # This is the image that is run by all nodes on selene for tests - PYTHON_VIRTUAL_ENV: /lustre/fsw/adlr/adlr-nlp/adlr_ci/cienv/bin/activate - TESTS_TO_RUN_AFTER_MERGE_REQ_APPROVED: MR_TESTS # Can specify levels - TESTS_TO_RUN_AFTER_MERGING: MR_TESTS NIGHTLY_TESTS # Can specify levels - TESTS_TO_RUN_ON_THIS_COMMIT: unit_tests - TEST_REGEX_ON_THIS_COMMIT: NONE #https://github.com/google/re2/wiki/Syntax (Can define regex as in this spec) e.g /.*gpt3.*/ - DISPLAY_OUTPUT: "True" # Set to true for new tests to copy the logs for creating golden truth file - TIME_LIMIT: "10:00" # Default time limit for all jobs - -unit_tests: - image: nvcr.io/nvidia/pytorch:23.04-py3 - tags: - - docker_local_runner - stage: test - script: - - pip install pytest-cov - - pip install pytest_mock - - pip install nltk - - pip install zarr "tensorstore==0.1.45" # for distributed checkpointing tests - - torchrun --nproc_per_node=8 -m pytest --cov-report=term 
--cov-report=html --cov=megatron/core tests/unit_tests - coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/' - artifacts: - paths: - - coverage - expire_in: 30 days - rules: - - when: always - -formatting: - image: nvcr.io/nvidia/pytorch:23.04-py3 - tags: - - docker_local_runner - stage: test - script: - - pip install --upgrade black==19.10b0 isort click==8.0.2 - - black megatron/core --check --verbose --diff - - isort megatron/core --check - rules: - - when: always - -.selene_test_resume_checkpoint_launcher: &selene-test-resume-checkpoint-launcher - tags: - - ssh_selene_runner - stage: test - script: &selene-test-resume-launcher-script - - echo "Running selene resume from checkpoint test. " - - pwd - - run_cmd="bash tests/functional_tests/shell_test_utils/run_selene_test_resume_checkpoint_launcher_script.sh RUN_MODEL=$RUN_MODEL TP_SIZE=$TP_SIZE PP_SIZE=$PP_SIZE VP_SIZE=$VP_SIZE NUM_NODES=$NUM_NODES SELENE_ADLR_CI_PATH=$SELENE_ADLR_CI_PATH CI_PIPELINE_ID=$CI_PIPELINE_ID RUN_NAME=$RUN_NAME PYTORCH_IMAGE=$PYTORCH_IMAGE DATA_DIR=$DATA_DIR TIME_LIMIT=$TIME_LIMIT" - - echo "$run_cmd" - - ${run_cmd} - - echo "Completed the job" - rules: - - if: $TEST_LEVEL =~ $TESTS_TO_RUN_ON_THIS_COMMIT || $CI_JOB_NAME =~ $TESTS_TO_RUN_ON_THIS_COMMIT || $CI_JOB_NAME =~ $TEST_REGEX_ON_THIS_COMMIT - when: always - - if: '$CI_COMMIT_REF_NAME == $CI_DEFAULT_BRANCH && $TEST_LEVEL =~ $TESTS_TO_RUN_AFTER_MERGING' - when: always - - if: $CI_MERGE_REQUEST_APPROVED && $TEST_LEVEL =~ $TESTS_TO_RUN_AFTER_MERGE_REQ_APPROVED - when: always - allow_failure: false - retry: 2 - -.selene_test_launcher: &selene-test-launcher - tags: - - ssh_selene_runner - stage: test - script: &selene-test-launcher-script - - echo "Running selene test" - - pwd - - run_cmd="bash tests/functional_tests/shell_test_utils/run_selene_test_launcher_script.sh RUN_MODEL=$RUN_MODEL TP_SIZE=$TP_SIZE PP_SIZE=$PP_SIZE VP_SIZE=$VP_SIZE NUM_NODES=$NUM_NODES SELENE_ADLR_CI_PATH=$SELENE_ADLR_CI_PATH CI_PIPELINE_ID=$CI_PIPELINE_ID RUN_NAME=$RUN_NAME MAX_STEPS=$MAX_STEPS PYTORCH_IMAGE=$PYTORCH_IMAGE DATA_DIR=$DATA_DIR USE_CORE=$USE_CORE USE_TE=$USE_TE TIME_LIMIT=$TIME_LIMIT" - - echo "$run_cmd" - - ${run_cmd} - - echo "Completed the job" - rules: - - if: $TEST_LEVEL =~ $TESTS_TO_RUN_ON_THIS_COMMIT || $CI_JOB_NAME =~ $TESTS_TO_RUN_ON_THIS_COMMIT || $CI_JOB_NAME =~ $TEST_REGEX_ON_THIS_COMMIT - when: always - - if: '$CI_COMMIT_REF_NAME == $CI_DEFAULT_BRANCH && $TEST_LEVEL =~ $TESTS_TO_RUN_AFTER_MERGING' - when: always - - if: $CI_MERGE_REQUEST_APPROVED && $TEST_LEVEL =~ $TESTS_TO_RUN_AFTER_MERGE_REQ_APPROVED - when: always - allow_failure: false - retry: 2 - -train.te_gpt3.345m_tp2_pp2_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 1 - TP_SIZE: 2 - PP_SIZE: 2 - NUM_NODES: 1 - MAX_STEPS: 50 - TIME_LIMIT: "20:00" - TEST_LEVEL: MR_TESTS - PYTORCH_IMAGE: nvcr.io/nvidia/pytorch:23.07-py3 - -train.gpt3_core.345m_tp4_pp1_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 4 - PP_SIZE: 1 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 1 - TEST_LEVEL: NIGHTLY_TESTS - -train.gpt3_core.345m_tp2_pp2_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 2 - PP_SIZE: 2 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 1 - TEST_LEVEL: MR_TESTS - -train.gpt3_core.345m_tp1_pp2_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 1 - PP_SIZE: 2 - NUM_NODES: 1 - 
MAX_STEPS: 50 - USE_CORE: 1 - TIME_LIMIT: "10:00" - TEST_LEVEL: NIGHTLY_TESTS - -train.gpt3_core.345m_tp1_pp4_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 1 - PP_SIZE: 4 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 1 - TEST_LEVEL: NIGHTLY_TESTS - -train.gpt3_core.345m_tp1_pp4_interleaved_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 1 - PP_SIZE: 4 - VP_SIZE: 1 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 1 - TEST_LEVEL: MR_TESTS - -train.gpt3_core.345m_tp1_pp2_1node_50steps_rope: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 1 - PP_SIZE: 2 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 1 - TEST_LEVEL: MR_TESTS - METADATA: rope_embeddings - ADDITIONAL_PARAMS: "--position-embedding-type rope" - -train.gpt3_core.345m_tp1_pp4_1node_50steps_swiglu: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 1 - PP_SIZE: 4 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 1 - TEST_LEVEL: MR_TESTS - METADATA: swiglu - ADDITIONAL_PARAMS: "--swiglu" - -train.gpt3_core.345m_tp1_pp4_1node_50steps_disable_bias_linear: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 1 - PP_SIZE: 4 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 1 - TEST_LEVEL: MR_TESTS - METADATA: disable_bias_linear - ADDITIONAL_PARAMS: "--disable-bias-linear" - -train.gpt3_core.345m_tp1_pp4_1node_50steps_untie_embeddings_and_outputs: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 1 - PP_SIZE: 4 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 1 - TEST_LEVEL: MR_TESTS - METADATA: untie_embeddings_and_outputs - ADDITIONAL_PARAMS: "--untie-embeddings-and-output-weights" - -train.gpt3_core.345m_tp1_pp4_1node_50steps_sequence_parallel: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 1 - PP_SIZE: 4 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 1 - TEST_LEVEL: MR_TESTS - METADATA: sequence_parallel - ADDITIONAL_PARAMS: "--sequence-parallel" - -train.gpt3.345m_tp4_pp1_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 4 - PP_SIZE: 1 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 0 - TEST_LEVEL: NIGHTLY_TESTS - -train.gpt3.345m_tp2_pp2_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 2 - PP_SIZE: 2 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 0 - TEST_LEVEL: MR_TESTS - -train.gpt3.345m_tp1_pp2_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 1 - PP_SIZE: 2 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 0 - TEST_LEVEL: NIGHTLY_TESTS - -train.gpt3.345m_tp1_pp4_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 1 - PP_SIZE: 4 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 0 - TEST_LEVEL: NIGHTLY_TESTS - -train.gpt3.345m_tp1_pp4_interleaved_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 1 - PP_SIZE: 4 - VP_SIZE: 1 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 0 - TEST_LEVEL: MR_TESTS - -resume.checkpoint.gpt3.345m_tp1_pp2_1node: - <<: *selene-test-resume-checkpoint-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - TP_SIZE: 1 - PP_SIZE: 2 - NUM_NODES: 1 - TIME_LIMIT: "15:00" - TEST_LEVEL: MR_TESTS - 
-train.gpt3.345m_tp1_pp1_1node_50steps_dist_optimizer: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 1 - PP_SIZE: 1 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 0 - TEST_LEVEL: MR_TESTS - METADATA: dist_optimizer - ADDITIONAL_PARAMS: "--use-distributed-optimizer" - -train.gpt3.345m_tp1_pp1_1node_50steps_overlap_grad_reduce: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 1 - PP_SIZE: 1 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 0 - TEST_LEVEL: NIGHTLY_TESTS - METADATA: overlap_grad_reduce - ADDITIONAL_PARAMS: "--overlap-grad-reduce" - -train.gpt3.345m_tp1_pp1_1node_50steps_dist_optimizer_overlap_grad_reduce: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 1 - PP_SIZE: 1 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 0 - TEST_LEVEL: NIGHTLY_TESTS - METADATA: dist_optimizer_overlap_grad_reduce - ADDITIONAL_PARAMS: "--use-distributed-optimizer --overlap-grad-reduce" - -train.gpt3.345m_tp4_pp1_1node_50steps_overlap_grad_reduce: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 4 - PP_SIZE: 1 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 0 - TEST_LEVEL: NIGHTLY_TESTS - METADATA: overlap_grad_reduce - ADDITIONAL_PARAMS: "--overlap-grad-reduce" - -train.gpt3.345m_tp4_pp1_1node_50steps_dist_optimizer_overlap_grad_reduce: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 4 - PP_SIZE: 1 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 0 - TEST_LEVEL: MR_TESTS - METADATA: dist_optimizer_overlap_grad_reduce - ADDITIONAL_PARAMS: "--use-distributed-optimizer --overlap-grad-reduce" - -train.gpt3.345m_tp1_pp4_1node_50steps_overlap_grad_reduce: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 1 - PP_SIZE: 4 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 0 - TEST_LEVEL: NIGHTLY_TESTS - METADATA: overlap_grad_reduce - ADDITIONAL_PARAMS: "--overlap-grad-reduce" - -train.gpt3.345m_tp1_pp4_interleaved_1node_50steps_overlap_grad_reduce: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 1 - PP_SIZE: 4 - VP_SIZE: 1 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 0 - TEST_LEVEL: NIGHTLY_TESTS - METADATA: overlap_grad_reduce - ADDITIONAL_PARAMS: "--overlap-grad-reduce" - -train.gpt3.345m_tp1_pp4_interleaved_1node_50steps_dist_optimizer_overlap_grad_reduce: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 1 - PP_SIZE: 4 - VP_SIZE: 1 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 0 - TEST_LEVEL: MR_TESTS - METADATA: dist_optimizer_overlap_grad_reduce - ADDITIONAL_PARAMS: "--use-distributed-optimizer --overlap-grad-reduce" - -train.gpt3.345m_tp2_pp2_1node_50steps_overlap_grad_reduce: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 2 - PP_SIZE: 2 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 0 - TEST_LEVEL: NIGHTLY_TESTS - METADATA: overlap_grad_reduce - ADDITIONAL_PARAMS: "--overlap-grad-reduce" - -train.gpt3_core.345m_cp2_tp2_pp1_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 2 - PP_SIZE: 1 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 1 - TIME_LIMIT: "20:00" - TEST_LEVEL: MR_TESTS - METADATA: "context_parallelism_cp2" - PYTORCH_IMAGE: "/lustre/fsw/adlr/adlr-nlp/adlr_ci/megatron/pytorch_23.10_flash_attn_1.0.9_context_parallelism.sqsh" - ADDITIONAL_PARAMS: 
"--context-parallel-size 2 --sequence-parallel --hidden-dropout 0.0 --attention-dropout 0.0" - -# Note: Core MoE models currently will run TE by default -train.te_core_moe_gpt3.345m_tp2_pp2_2experts_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 2 - PP_SIZE: 2 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 1 - TEST_LEVEL: NIGHTLY_TESTS - METADATA: "te_2experts" - ADDITIONAL_PARAMS: "--num-experts 2" - -train.te_core_moe_gpt3.345m_tp2_pp2_4experts2parallel_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 2 - PP_SIZE: 2 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 1 - TEST_LEVEL: NIGHTLY_TESTS - METADATA: "te_4experts2parallel" - ADDITIONAL_PARAMS: "--sequence-parallel --num-experts 4 --expert-model-parallel-size 2" - -train.te_core_moe_gpt3.345m_tp2_pp1_4experts2parallel_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 2 - PP_SIZE: 1 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 1 - TEST_LEVEL: MR_TESTS - METADATA: "te_8experts2parallel" - ADDITIONAL_PARAMS: "--sequence-parallel --num-experts 8 --expert-model-parallel-size 2" - -train.moe_gpt3.345m_tp2_pp2_4experts_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: gpt3 - USE_TE: 0 - TP_SIZE: 2 - PP_SIZE: 2 - NUM_NODES: 1 - MAX_STEPS: 50 - USE_CORE: 0 - TEST_LEVEL: NIGHTLY_TESTS - METADATA: "4experts" - ADDITIONAL_PARAMS: "--num-experts 4" - -train.bert.345m_tp4_pp1_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: bert - TP_SIZE: 4 - PP_SIZE: 1 - NUM_NODES: 1 - MAX_STEPS: 50 - TIME_LIMIT: "10:00" - TEST_LEVEL: NIGHTLY_TESTS - -train.bert.345m_tp2_pp2_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: bert - TP_SIZE: 2 - PP_SIZE: 2 - NUM_NODES: 1 - MAX_STEPS: 50 - TEST_LEVEL: MR_TESTS - -train.bert.345m_tp1_pp2_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: bert - TP_SIZE: 1 - PP_SIZE: 2 - NUM_NODES: 1 - MAX_STEPS: 50 - TEST_LEVEL: NIGHTLY_TESTS - -train.bert.345m_tp1_pp4_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: bert - TP_SIZE: 1 - PP_SIZE: 4 - NUM_NODES: 1 - MAX_STEPS: 50 - TEST_LEVEL: NIGHTLY_TESTS - -train.bert.345m_tp1_pp4_interleaved_1node_50steps: - <<: *selene-test-launcher - variables: - <<: [*VARS] - RUN_MODEL: bert - TP_SIZE: 1 - PP_SIZE: 4 - VP_SIZE: 2 - NUM_NODES: 1 - MAX_STEPS: 50 - TEST_LEVEL: MR_TESTS - -resume.checkpoint.bert.345m_tp1_pp2_1node: - <<: *selene-test-resume-checkpoint-launcher - variables: - <<: [*VARS] - RUN_MODEL: bert - TP_SIZE: 1 - PP_SIZE: 2 - NUM_NODES: 1 - TEST_LEVEL: MR_TESTS - -cleanup.selene: - tags: - - ssh_selene_runner - stage: cleanup - variables: - <<: [*VARS] - script: - - set +e - - NUM_CLEANUP=`find ${SELENE_ADLR_CI_PATH}/* -type d -ctime +20 | grep -v data | wc -l` - - find ${SELENE_ADLR_CI_PATH}/* -type d -ctime +20 | grep -v data | xargs rm -rf - - find ${SELENE_ADLR_CI_PATH}/* -type d -name "checkpoints" -ctime +2 | grep -v data | xargs rm -rf - - echo "Finished cleaning $NUM_CLEANUP directories older than 20 days everything in Selene" - allow_failure: true - rules: - - when: always + - functional_tests + - publish + +default: + interruptible: true + retry: + max: 2 + when: runner_system_failure + +variables: + UNIT_TEST: + value: 'yes' + options: + - 'yes' + - 'no' + description: To run the funtional test suite + UNIT_TEST_REPEAT: + value: 
'1' + description: 'Number of repetitions' + UNIT_TEST_TIMEOUT: + value: '30' + description: Timeout (minutes) for Unit tests (all repeats) + FUNCTIONAL_TEST: + value: 'yes' + options: + - 'yes' + - 'no' + description: To run the funtional test suite + FUNCTIONAL_TEST_SCOPE: + value: 'mr' + options: + - 'mr' + - 'nightly' + - 'weekly' + - 'pre-release' + - 'release' + description: 'Testsuite to run (only for FUNCTIONAL_TEST=yes)' + FUNCTIONAL_TEST_REPEAT: + value: '5' + description: 'Number of repetitions per test' + FUNCTIONAL_TEST_TIME_LIMIT: + value: '2700' + description: 'Timeout in seconds per test' + FUNCTIONAL_TEST_CASES: + value: 'all' + description: "Comma-separated list of test_cases to run. Use 'all' to run the full suite." + FUNCTIONAL_TEST_CLUSTER_A100: + value: 'dgxa100_dracooci' + options: + - 'dgxa100_dracooci' + - 'dgxa100_dracooci-ord' + description: 'Cluster for A100 workloads' + FUNCTIONAL_TEST_CLUSTER_H100: + value: 'dgxh100_eos' + options: + - 'dgxh100_coreweave' + - 'dgxh100_eos' + description: 'Cluster for H100 workloads' + FUNCTIONAL_TEST_NAME: + description: 'Name of functional test run (only for pre-release and release)' + PUBLISH: + value: 'no' + options: + - 'yes' + - 'no' + description: Build and publish a wheel to PyPi + PUBLISH_SCOPE: + value: 'code-freeze' + options: + - 'code-freeze' + - 'release' + description: Type of publish (freeze or final release) + + # CI wide variables + CI_MCORE_LTS_IMAGE: ${GITLAB_ENDPOINT}:5005/adlr/megatron-lm/mcore_ci_lts + CI_MCORE_DEV_IMAGE: ${GITLAB_ENDPOINT}:5005/adlr/megatron-lm/mcore_ci_dev + CI_NEMO_IMAGE: ${GITLAB_ENDPOINT}:5005/adlr/megatron-lm/nemo_ci + UTILITY_IMAGE: ${GITLAB_ENDPOINT}:5005/adlr/megatron-lm/mcore_utility + +include: + - .gitlab/stages/00.pre.yml + - .gitlab/stages/01.test.yml + - .gitlab/stages/02.functional-tests.yml + - .gitlab/stages/03.publish.yml diff --git a/.gitlab/labeler-config.yml b/.gitlab/labeler-config.yml new file mode 100644 index 0000000000..3dc4001cd7 --- /dev/null +++ b/.gitlab/labeler-config.yml @@ -0,0 +1,33 @@ +CI: +- .gitlab-ci.yml +- Dockerfile.ci.lts +- Dockerfile.ci.dev +- .github/** +- .gitlab/** + +Datasets: +- megatron/core/datasets/** + +BERT: +- megatron/core/models/bert/** + +GPT: +- megatron/core/models/gpt/** + +RETRO: +- megatron/core/models/retro/** + +Dist-Ckpt: +- megatron/core/dist_checkpointing + +Dist-Opt: +- megatron/core/optimizer/distrib_optimizer + +Inference: +- megatron/core/inference + +MoE: +- megatron/core/transformer/moe + +Tests: +- tests/** \ No newline at end of file diff --git a/.gitlab/stages/00.pre.yml b/.gitlab/stages/00.pre.yml new file mode 100644 index 0000000000..b5af2eeb88 --- /dev/null +++ b/.gitlab/stages/00.pre.yml @@ -0,0 +1,199 @@ +include: + - template: Security/Secret-Detection.gitlab-ci.yml + +.pre_rules: + rules: + - if: $CI_PIPELINE_SOURCE == 'merge_request_event' && $CI_MERGE_REQUEST_TARGET_BRANCH_PROTECTED != "true" + allow_failure: true + when: always + - if: $CI_PIPELINE_SOURCE == 'merge_request_event' + - when: never + stage: .pre + +.dind_rules: + image: docker:26.1.4-dind + variables: + DOCKER_HOST: unix:///var/run/docker.sock + before_script: + - docker system prune -a --filter "until=36h" -f || true + - echo "$NGC_API_KEY" | docker login nvcr.io -u '$oauthtoken' --password-stdin + - echo "$CI_REGISTRY_PASSWORD" | docker login $CI_REGISTRY -u $CI_REGISTRY_USER --password-stdin + +pre:mirror_to_github: + rules: + - if: '$CI_COMMIT_REF_PROTECTED == "true" && $CI_PIPELINE_SOURCE == "push"' + - when: never + tags: + - 
arch/amd64 + - env/prod + - origin/jet-fleet + - owner/jet-core + - purpose/utility + - team/megatron + stage: .pre + image: python:3.10 + variables: + GIT_STRATEGY: 'clone' + script: + - git checkout $CI_COMMIT_BRANCH + - git remote add github https://ko3n1g:$GH_TOKEN@github.com/NVIDIA/Megatron-LM.git || true + - git push -u github $CI_COMMIT_BRANCH + +pre:create_ci_branches: + rules: + - if: '$CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH && $CI_PIPELINE_SOURCE == "push"' + - when: never + parallel: + matrix: + - branch: ci-unit-test-extended + - branch: ci-rebuild-mcore-nemo-image + - branch: ci-mr + - branch: ci-nightly + - branch: ci-weekly + - branch: ci-pre-release + tags: + - arch/amd64 + - env/prod + - origin/jet-fleet + - owner/jet-core + - purpose/utility + - team/megatron + stage: .pre + image: python:3.10 + variables: + GIT_STRATEGY: 'clone' + script: + - git remote set-url origin "https://gitlab-ci-token:${PROJECT_ACCESS_TOKEN_MCORE}@${GITLAB_ENDPOINT}/adlr/megatron-lm.git" + - git switch --force-create $branch + - git push --force -u origin $branch + +pre:label_merge_request: + extends: [.pre_rules] + image: golang:1.22 + tags: + - mcore-docker-node-small + before_script: + - git clone -b nv https://${GITLAB_ENDPOINT}/okoenig/gitlab-mr-labeler.git + - cd gitlab-mr-labeler + - go install . + - cd .. + - go install github.com/itchyny/gojq/cmd/gojq@latest + - | + echo LABELS=$(curl --header "PRIVATE-TOKEN: ${PROJECT_ACCESS_TOKEN_MCORE}" --url "https://${GITLAB_ENDPOINT}/api/v4/projects/${CI_PROJECT_ID}/merge_requests/${CI_MERGE_REQUEST_IID}" | gojq '.labels | join(",")') > labels + script: + - gitlab-mr-labeler -f .gitlab/labeler-config.yml -t ${PROJECT_ACCESS_TOKEN_MCORE} --debug true + after_script: + - | + source labels + curl --header "PRIVATE-TOKEN: ${PROJECT_ACCESS_TOKEN_MCORE}" --url "https://${GITLAB_ENDPOINT}/api/v4/projects/${CI_PROJECT_ID}/merge_requests/${CI_MERGE_REQUEST_IID}" --data-urlencode "add_labels=$LABELS" -X PUT + +pre:maybe_cherry_pick_commit: + rules: + - if: '$CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH && $CI_PIPELINE_SOURCE == "push"' + - when: never + tags: + - arch/amd64 + - env/prod + - origin/jet-fleet + - owner/jet-core + - purpose/utility + - team/megatron + stage: .pre + image: nentangso/alpine-git-curl-jq + variables: + GIT_STRATEGY: 'clone' + script: + - set -x + - set +e + - SHA=$(git rev-list --no-merges -n 1 HEAD) + - MESSAGE=$(git log -n 1 --pretty=format:%s $SHA) + - MR_ID=$(echo $MESSAGE | awk -F'!' 
'{print $2}' | awk '{print $1}' ) + - git remote set-url origin "https://gitlab-ci-token:${PROJECT_ACCESS_TOKEN_MCORE}@${GITLAB_ENDPOINT}/$CI_PROJECT_NAMESPACE/megatron-lm.git" + - git config --global user.email "mcore-bot@nvidia.com" + - git config --global user.name "Mcore Bot" + - | + MR=$(curl --header "PRIVATE-TOKEN: ${PROJECT_ACCESS_TOKEN_MCORE}" --url "https://${GITLAB_ENDPOINT}/api/v4/projects/${CI_PROJECT_ID}/merge_requests/${MR_ID}") + + LABELS=$(echo -E $MR | jq '.labels | join(",")' | tr -d '"') + AUTHOR_ID=$(echo -E $MR | jq '.author.id' | tr -d '"') + AUTHOR_NAME=$(echo -E $MR | jq '.author.username' | tr -d '"') + TITLE=$(echo -E $MR | jq '.title' | tr -d '"') + MILESTONE_ID=$(echo -E $MR | jq '.milestone.id' | tr -d '"') + TARGET_BRANCHES=$(echo "$LABELS" | grep -o 'core_[^,]*') + + if [[ $TARGET_BRANCHES == "" ]]; then + echo Nothing to cherry pick + exit 0 + fi + + echo $TARGET_BRANCHES | while read -r RELEASE_BRANCH ; do + TARGET_BRANCH_EXISTS_OK=$([[ "$(git ls-remote --heads origin refs/heads/$RELEASE_BRANCH)" != "" ]] && echo true || echo false) + + if [[ "$TARGET_BRANCH_EXISTS_OK" == "false" ]]; then + echo Release branch does not yet exist, will not cherry-pick + continue + fi + + ( + git fetch origin $RELEASE_BRANCH:$RELEASE_BRANCH + git switch --force-create cherry-pick-$MR_ID-$RELEASE_BRANCH $RELEASE_BRANCH + git cherry-pick $SHA + git push -u origin --force cherry-pick-$MR_ID-$RELEASE_BRANCH + git checkout ${CI_DEFAULT_BRANCH:-main} + ) + + CHERRYPICK_SUCCESSFUL=$? + + if [[ $CHERRYPICK_SUCCESSFUL -eq 0 ]]; then + curl \ + --header "PRIVATE-TOKEN: $PAT" \ + --url https://${GITLAB_ENDPOINT}/api/v4/projects/${CI_PROJECT_ID}/merge_requests \ + -d "source_branch=cherry-pick-$MR_ID-$RELEASE_BRANCH" \ + -d "target_branch=$RELEASE_BRANCH" \ + -d "title=Cherry pick \`$TITLE ($MR_ID)\` into \`$RELEASE_BRANCH\`" \ + -d "labels=cherry-pick" \ + -d "reviewer_ids=$AUTHOR_ID" \ + -d "milestone_id=$MILESTONE_ID" \ + -d "description=[🤖]: Hi @$AUTHOR_NAME 👋,

we've cherry picked \`$TITLE ($MR_ID)\` into \`$RELEASE_BRANCH\` for you! 🚀

Please review and approve this cherry pick by your convenience\!" + + else + URL=https://${GITLAB_ENDPOINT}/ADLR/megatron-lm/-/merge_requests/$MR_ID + + MESSAGE='{ + "blocks": [ + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": "beep boop 🤖: Cherry-pick of <'$URL'|!'$MR_ID'> failed\ncc '$SLACK_ADMIN'" + } + } + ] + }' + + curl -X POST -H "Content-type: application/json" --data "$MESSAGE" ${MCORE_NOTIFICATION_HOOK} + + fi + + done + interruptible: false + +pre:check_milestone: + extends: [.pre_rules] + image: badouralix/curl-jq + tags: + - arch/amd64 + - env/prod + - origin/jet-fleet + - owner/jet-core + - purpose/utility + - team/megatron + script: + - env + - | + MILESTONE=$(curl --header "PRIVATE-TOKEN: ${PROJECT_ACCESS_TOKEN_MCORE}" --url "https://${GITLAB_ENDPOINT}/api/v4/projects/${CI_PROJECT_ID}/merge_requests/${CI_MERGE_REQUEST_IID}" | jq '.milestone') + - | + if [[ "$MILESTONE" == "null" ]]; then + echo Please assign a Milestone to this MR! + exit 1 + fi diff --git a/.gitlab/stages/01.test.yml b/.gitlab/stages/01.test.yml new file mode 100644 index 0000000000..e2ccf40ed1 --- /dev/null +++ b/.gitlab/stages/01.test.yml @@ -0,0 +1,581 @@ +.test_rules: + rules: + - if: $UNIT_TEST == 'yes' && $CI_PIPELINE_SOURCE == 'merge_request_event' && $CI_MERGE_REQUEST_TARGET_BRANCH_PROTECTED != "true" + allow_failure: true + when: on_success + - when: on_success + stage: test + +include: + - template: Security/Secret-Detection.gitlab-ci.yml + +test:build_image: + extends: [.test_rules, .dind_rules] + tags: + - arch/amd64 + - origin/jet-fleet + - env/prod + - ${TAG} + services: + - name: docker:24.0.5-dind + variables: + HEALTHCHECK_TCP_PORT: '2376' + timeout: 45m + parallel: + matrix: + - IMAGE: CI_MCORE_LTS_IMAGE + FILE: Dockerfile.ci.lts + BASE_IMAGE: nvcr.io/nvidia/pytorch:24.01-py3 + - IMAGE: CI_MCORE_DEV_IMAGE + FILE: Dockerfile.ci.dev + BASE_IMAGE: nvcr.io/nvidia/pytorch:24.10-py3 + - IMAGE: CI_NEMO_IMAGE + FILE: Dockerfile.ci.dev + BASE_IMAGE: nvcr.io/nvidian/nemo:nightly + - IMAGE: UTILITY_IMAGE + FILE: Dockerfile.linting + BASE_IMAGE: python:3.10 + variables: + DOCKER_HOST: tcp://docker:2376 + DOCKER_TLS_CERTDIR: '/certs' + DOCKER_TLS_VERIFY: 1 + DOCKER_CERT_PATH: '$DOCKER_TLS_CERTDIR/client' + TAG: purpose/builder-large + STAGE: jet + MCORE_BACKWARDS_REF: core_r0.11.0 + script: + - apk add bash + - | + bash -c ' + set -x + env + eval "IMAGE=\$$IMAGE" + + docker context create tls-environment + docker buildx create --name container --driver=docker-container --use tls-environment + + ADDITIONAL_PARAMS=() + + if [[ "$CI_COMMIT_BRANCH" == "ci-rebuild-mcore-nemo-image" || "$CI_COMMIT_BRANCH" == "main" ]]; then + ADDITIONAL_PARAMS+=("--pull") + ADDITIONAL_PARAMS+=("--cache-to type=registry,ref=${IMAGE}-buildcache:main") + fi + + if [[ "$CI_COMMIT_BRANCH" == "ci-nightly-a100" ]]; then + ADDITIONAL_PARAMS+=("-t ${IMAGE}:nightly") + fi + + echo $(git rev-parse HEAD) + + DOCKER_BUILDKIT=1 docker build \ + --secret id=JET_INDEX_URLS \ + --secret id=LOGGER_INDEX_URL \ + --target $STAGE \ + -f $FILE \ + -t ${IMAGE}:${CI_PIPELINE_ID} \ + --builder=container \ + --build-arg CACHEBUST=$(cat /proc/sys/kernel/random/uuid) \ + --build-arg MCORE_REPO=${CI_REPOSITORY_URL} \ + --build-arg MCORE_REF=$CI_COMMIT_SHA \ + --build-arg MCORE_BACKWARDS_REF=$MCORE_BACKWARDS_REF \ + --cache-to type=registry,ref=${IMAGE}-buildcache:${CI_PIPELINE_ID} \ + --cache-to type=registry,ref=${IMAGE}-buildcache:${CI_MERGE_REQUEST_IID:-noop} \ + --cache-from type=registry,ref=${IMAGE}-buildcache:main \ + 
--cache-from type=registry,ref=${IMAGE}-buildcache:${CI_PIPELINE_ID} \ + --cache-from type=registry,ref=${IMAGE}-buildcache:${CI_MERGE_REQUEST_IID:-noop} \ + --build-arg FROM_IMAGE_NAME=$BASE_IMAGE \ + --push \ + ${ADDITIONAL_PARAMS[@]} . + ' + retry: + max: 2 + +test:unit_tests_configure: + extends: [.test_rules] + needs: + - test:build_image + image: ${UTILITY_IMAGE}:${CI_PIPELINE_ID} + tags: + - arch/amd64 + - env/prod + - origin/jet-fleet + - owner/jet-core + - purpose/utility + - team/megatron + before_script: + - git rm -r tests/test_utils/local_recipes || true + - git submodule add --force https://gitlab-ci-token:${CI_JOB_TOKEN}@${GITLAB_ENDPOINT}/ADLR/megatron-lm-convergence-tests.git tests/test_utils/local_recipes + - ls tests/test_utils/local_recipes + script: + - set -x + - | + A100_CLUSTER=$([[ "$FUNCTIONAL_TEST_CLUSTER_A100" != "" ]] && echo $FUNCTIONAL_TEST_CLUSTER_A100 || echo $DEFAULT_A100_CLUSTER) + H100_CLUSTER=$([[ "$FUNCTIONAL_TEST_CLUSTER_H100" != "" ]] && echo $FUNCTIONAL_TEST_CLUSTER_H100 || echo $DEFAULT_H100_CLUSTER) + - | + ARGS=( + "--scope unit-tests" + "--n-repeat ${UNIT_TEST_REPEAT}" + "--time-limit $(( UNIT_TEST_TIMEOUT * 60 ))" + "--test-cases all" + "--a100-cluster dgxa100_dracooci-ord" + "--h100-cluster dgxh100_coreweave" + "--h100-partition batch_short,batch" + "--container-image ${UTILITY_IMAGE}" + "--container-tag ${CI_PIPELINE_ID}" + "--dependent-job test:unit_tests_configure" + ) + - | + export PYTHONPATH=$(pwd) + python tests/test_utils/python_scripts/generate_jet_trigger_job.py \ + ${ARGS[@]} \ + --environment "lts" \ + --tag "legacy" \ + --output-path "unit-test-job-lts-legacy.yaml" + - | + export PYTHONPATH=$(pwd) + python tests/test_utils/python_scripts/generate_jet_trigger_job.py \ + ${ARGS[@]} \ + --environment "lts" \ + --tag "latest" \ + --output-path "unit-test-job-lts-latest.yaml" + - | + export PYTHONPATH=$(pwd) + python tests/test_utils/python_scripts/generate_jet_trigger_job.py \ + ${ARGS[@]} \ + --environment "dev" \ + --tag "legacy" \ + --output-path "unit-test-job-dev-legacy.yaml" + - | + export PYTHONPATH=$(pwd) + python tests/test_utils/python_scripts/generate_jet_trigger_job.py \ + ${ARGS[@]} \ + --environment "dev" \ + --tag "latest" \ + --output-path "unit-test-job-dev-latest.yaml" + rules: + - if: $UNIT_TEST == 'yes' && $CI_PIPELINE_SOURCE == 'merge_request_event' && $CI_MERGE_REQUEST_TARGET_BRANCH_PROTECTED != "true" + allow_failure: true + when: on_success + - if: $UNIT_TEST == 'yes' && $UNIT_TEST_REPEAT != '0' + when: on_success + artifacts: + paths: + - unit-test-job-dev-legacy.yaml + - unit-test-job-dev-latest.yaml + - unit-test-job-lts-legacy.yaml + - unit-test-job-lts-latest.yaml + - tests/test_utils/local_recipes + +.unit_tests_run: + needs: + - test:formatting + - test:copyright + - job: test:secret_detection + optional: true + - test:unit_tests_configure + extends: [.test_rules] + trigger: + include: + - artifact: unit-test-job-$ENVIRONMENT-$TAG.yaml + job: test:unit_tests_configure + strategy: depend + variables: + RO_API_TOKEN: $PAT + CONTAINER_TAG: $CI_PIPELINE_ID + CI_MCORE_LTS_IMAGE: $CI_MCORE_LTS_IMAGE + GITLAB_ENDPOINT: $GITLAB_ENDPOINT + PARENT_PIPELINE_ID: $CI_PIPELINE_ID + inherit: + variables: true + rules: + - if: $UNIT_TEST == 'yes' && $CI_PIPELINE_SOURCE == 'merge_request_event' && $CI_MERGE_REQUEST_TARGET_BRANCH_PROTECTED != "true" + allow_failure: true + when: on_success + - if: $UNIT_TEST == 'yes' && $UNIT_TEST_REPEAT != '0' + when: on_success + +test:unit_tests_pyt(DEV)_mcore(legacy): + extends: 
[.unit_tests_run] + variables: + ENVIRONMENT: dev + TAG: legacy + +test:unit_tests_pyt(LTS)_mcore(legacy): + extends: [.unit_tests_run] + variables: + ENVIRONMENT: lts + TAG: legacy + +test:unit_tests_pyt(DEV)_mcore(latest): + extends: [.unit_tests_run] + variables: + ENVIRONMENT: dev + TAG: latest + +test:unit_tests_pyt(LTS)_mcore(latest): + extends: [.unit_tests_run] + variables: + ENVIRONMENT: lts + TAG: latest + +test:notify_unit_tests: + extends: [.test_rules] + image: badouralix/curl-jq + needs: + - test:unit_tests_pyt(DEV)_mcore(latest) + - test:unit_tests_pyt(LTS)_mcore(latest) + tags: + - arch/amd64 + - env/prod + - origin/jet-fleet + - owner/jet-core + - purpose/utility + - team/megatron + script: + - apk add bash + - apk add --update coreutils + - env + - export WEBHOOK_URL=${MCORE_NOTIFICATION_HOOK} + - export RO_API_TOKEN=${PROJECT_ACCESS_TOKEN_MCORE} + - export GITLAB_ENDPOINT + - export CONTEXT="unit-tests-extended" + - export DATE=$(date +"%Y-%m-%d") + - bash tests/test_utils/shell_scripts/notify.sh ${CI_PIPELINE_ID} "test:unit_tests_pyt" + artifacts: + when: always + paths: + - scripts + rules: + - if: $CI_PIPELINE_SOURCE == "schedule" && $CI_COMMIT_BRANCH == "ci-unit-test-extended" + when: always + - when: never + +test:docs_build: + extends: [.test_rules] + image: ${UTILITY_IMAGE}:${CI_PIPELINE_ID} + tags: + - arch/amd64 + - env/prod + - origin/jet-fleet + - owner/jet-core + - purpose/utility + - team/megatron + needs: [test:build_image] + script: + - cd .. + - rm -rf documentation && git clone https://gitlab-ci-token:${CI_JOB_TOKEN}@${GITLAB_ENDPOINT}/nemo-megatron-core-tme/documentation.git + - mv megatron-lm/ documentation/ + - cd documentation/ + - ./repo docs + +test:formatting: + extends: [.test_rules] + image: ${UTILITY_IMAGE}:${CI_PIPELINE_ID} + tags: + - arch/amd64 + - env/prod + - origin/jet-fleet + - owner/jet-core + - purpose/utility + - team/megatron + needs: [test:build_image] + variables: + GIT_STRATEGY: 'clone' + script: + - | + if [[ "$CI_PIPELINE_SOURCE" != "merge_request_event" ]]; then + exit 0 + fi + - set +e + - git fetch origin main:main + - | + if [[ "$CI_MERGE_REQUEST_PROJECT_PATH" == "$CI_MERGE_REQUEST_SOURCE_PROJECT_PATH" ]]; then + bash tools/autoformat.sh + set -e + git fetch origin $CI_MERGE_REQUEST_SOURCE_BRANCH_NAME + git checkout $CI_MERGE_REQUEST_SOURCE_BRANCH_NAME + git config --global user.email "mcore-bot@nvidia.com" + git config --global user.name "Mcore Bot" + git remote set-url origin "https://gitlab-ci-token:${PAT}@${GITLAB_ENDPOINT}/$CI_PROJECT_NAMESPACE/megatron-lm.git" + git add -A . 
+ git commit -m "chore: Format files" || true + git push -u origin $CI_MERGE_REQUEST_SOURCE_BRANCH_NAME + fi + - env + - BASE_REF="$CI_MERGE_REQUEST_TARGET_BRANCH_NAME" CHECK_ONLY=true SKIP_DOCS=$([[ "$CI_MERGE_REQUEST_LABELS" == *"Skip docs"* ]] && echo "true" || echo "false") bash tools/autoformat.sh + +test:copyright: + extends: [.test_rules] + tags: + - arch/amd64 + - env/prod + - origin/jet-fleet + - owner/jet-core + - purpose/utility + - team/megatron + image: ${UTILITY_IMAGE}:${CI_PIPELINE_ID} + needs: [test:build_image] + script: + - git fetch origin main + - bash tools/copyright.sh + +# Override from template +secret_detection: + rules: + - when: never + +# Inherit and modify template +test:secret_detection: + tags: + - arch/amd64 + - env/prod + - origin/jet-fleet + - owner/jet-core + - purpose/utility + - team/megatron + extends: ['.secret-analyzer'] + variables: + GIT_DEPTH: 0 + SECRET_DETECTION_LOG_OPTIONS: ${CI_MERGE_REQUEST_DIFF_BASE_SHA}..${CI_COMMIT_SHA} + allow_failure: true + rules: + - if: $CI_PIPELINE_SOURCE == "merge_request_event" + - when: never + script: + - apk add jq + - /analyzer run + - | + if [[ $(cat gl-secret-detection-report.json | jq '.vulnerabilities | length > 0') == true ]]; then + echo "Atleast one vulnerability has been found" + cat gl-secret-detection-report.json | jq '.' + exit 1 + fi + +test:pypi_build_wheel: + extends: [.test_rules] + image: + name: quay.io/pypa/manylinux_2_28_x86_64 + entrypoint: [''] + services: + - name: docker:24.0.5-dind + variables: + HEALTHCHECK_TCP_PORT: '2376' + tags: + - arch/amd64 + - env/prod + - origin/jet-fleet + - owner/jet-core + - purpose/builder-small + - team/megatron + variables: + PUBLISH_DRYRUN: 'yes' + PY_ENV: pytorch_24.10 + script: + - echo $PUBLISH_DRYRUN + - > + if [ "$PUBLISH_DRYRUN" = "yes" ]; then + PRE_RELEASE=$(sed -n "s/.*PRE_RELEASE = '\(.*\)'/\1/p" megatron/core/package_info.py) + sed -i "/^PRE_RELEASE/c\PRE_RELEASE = '${PRE_RELEASE}.dev$((RANDOM % 900000 + 100000))'" megatron/core/package_info.py + fi + + + - /opt/python/cp310-cp310/bin/python -m build + - /opt/python/cp311-cp311/bin/python -m build + - auditwheel repair dist/*.whl + - rm -rf dist/*.whl + + - pushd megatron/core + - EXPECTED_RELEASE_NUMBER=$(/opt/python/cp311-cp311/bin/python -c "import package_info; print(package_info.__version__)") + - popd + - echo "EXPECTED_RELEASE_NUMBER=$EXPECTED_RELEASE_NUMBER" | tee -a build.env + artifacts: + paths: + - megatron/core/package_info.py + - wheelhouse/ + - dist/ + reports: + dotenv: build.env + +test:pypi_test_wheel: + extends: [.test_rules] + image: + name: python:3.11 + entrypoint: [''] + needs: [test:pypi_build_wheel] + services: + - name: docker:24.0.5-dind + variables: + HEALTHCHECK_TCP_PORT: '2376' + tags: + - arch/amd64 + - env/prod + - origin/jet-fleet + - owner/jet-core + - purpose/builder-small + - team/megatron + variables: + PUBLISH_DRYRUN: 'yes' + script: + - rm -rf megatron + - pip install wheelhouse/*cp311*.whl + + - RELEASE_NUMBER=$(python -c "from megatron import core; print(core.__version__)") + - > + echo "$EXPECTED_RELEASE_NUMBER" == "$RELEASE_NUMBER" + + + - test "$EXPECTED_RELEASE_NUMBER" == "$RELEASE_NUMBER" + - echo "RELEASE_NUMBER=$EXPECTED_RELEASE_NUMBER" | tee -a build.env + artifacts: + reports: + dotenv: build.env + paths: + - wheelhouse/ + - dist/ + +test:pypi_push_wheel: + extends: [.test_rules] + image: python:3.11 + tags: + - arch/amd64 + - env/prod + - origin/jet-fleet + - owner/jet-core + - purpose/utility + - team/megatron + needs: [test:pypi_test_wheel] + 
variables: + PUBLISH_DRYRUN: 'yes' + timeout: 3m + script: + - > + if [ "$PUBLISH_DRYRUN" = "yes" ]; then + REPOSITORY=testpypi + export TWINE_USERNAME=$TWINE_TEST_USERNAME + export TWINE_PASSWORT=$TWINE_TEST_PASSWORD + else + REPOSITORY=pypi + export TWINE_USERNAME=$TWINE_PROD_USERNAME + export TWINE_PASSWORT=$TWINE_PROD_PASSWORD + fi + + - ls -al dist/ + - ls -al wheelhouse/ + - pip install twine + - > + for i in 1 2 3 4 5; do + twine upload --verbose -u $TWINE_USERNAME -p $TWINE_PASSWORT --repository $REPOSITORY wheelhouse/* dist/* && break || sleep $(( 60*2**i )); + done + + + rules: + - if: $UNIT_TEST == 'yes' && $CI_PIPELINE_SOURCE == 'merge_request_event' && $CI_MERGE_REQUEST_TARGET_BRANCH_PROTECTED != "true" + allow_failure: true + when: on_success + - when: on_success + allow_failure: true + +test:gh_release: + extends: [.test_rules] + needs: [test:pypi_test_wheel] + tags: + - arch/amd64 + - env/prod + - origin/jet-fleet + - owner/jet-core + - purpose/utility + - team/megatron + image: badouralix/curl-jq + variables: + PUBLISH_DRYRUN: 'yes' + script: + - NAME="NVIDIA Megatron Core $RELEASE_NUMBER" + - IS_PRERELEASE=$([[ "$RELEASE_NUMBER" == *rc* ]] && echo "true" || echo "false") + - > + if [[ "$IS_PRERELEASE" == "true" ]]; then + DATE=$(date +"%Y-%m-%d") + CHANGELOG="Prerelease: $NAME ($DATE)" + else + CHANGELOG=$(awk '/^## '"$NAME"'/{flag=1; next} /^## /{flag=0} flag' CHANGELOG.md) + CHANGELOG=$(echo "$CHANGELOG" | sed '/./!d') + fi + - > + PAYLOAD=$(jq -nc \ + --arg TAG_NAME "v${RELEASE_NUMBER}" \ + --arg CI_COMMIT_SHA "$CI_COMMIT_SHA" \ + --arg NAME "$NAME" \ + --arg BODY "$CHANGELOG" \ + --argjson PRERELEASE "$IS_PRERELEASE" \ + '{ + "tag_name": $TAG_NAME, + "target_commitish": $CI_COMMIT_SHA, + "name": $NAME, + "body": $BODY, + "draft": false, + "prerelease": $PRERELEASE, + "generate_release_notes": false + }' + ) + echo -E "$PAYLOAD" > payload.txt + - cat payload.txt + - > + CMD=$(echo -E 'curl -L \ + -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer '"$GH_TOKEN"'" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ + https://api.github.com/repos/NVIDIA/Megatron-LM/releases \ + -d @payload.txt + ') + + - > + if [[ "$PUBLISH_DRYRUN" == "yes" ]]; then + echo -E "$CMD" + else + eval "$CMD" + fi + + +test:notify_release: + needs: [test:pypi_test_wheel, test:pypi_push_wheel, test:gh_release] + extends: [.test_rules] + image: badouralix/curl-jq + tags: + - arch/amd64 + - env/prod + - origin/jet-fleet + - owner/jet-core + - purpose/utility + - team/megatron + variables: + PUBLISH_DRYRUN: 'yes' + script: + - URL="https://github.com/NVIDIA/Megatron-LM/releases/tag/core_r$RELEASE_NUMBER" + - > + MESSAGE='{ + "blocks": [ + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": "Releasebot 🤖: Megatron-Core released <'$URL'|core_r'"$RELEASE_NUMBER"'> 🚀" + } + } + ] + }' + + + - echo "$MESSAGE" + - > + CMD=$(echo curl \ + -X POST \ + -H "Content-type: application/json" \ + --data "$MESSAGE" ${MCORE_NOTIFICATION_HOOK_MAIN} + ) + + if [[ "$PUBLISH_DRYRUN" == "yes" ]]; then + echo "$CMD" + else + eval "$CMD" + fi + diff --git a/.gitlab/stages/02.functional-tests.yml b/.gitlab/stages/02.functional-tests.yml new file mode 100644 index 0000000000..ddf3fd85df --- /dev/null +++ b/.gitlab/stages/02.functional-tests.yml @@ -0,0 +1,187 @@ +.functional_tests_rules: + stage: functional_tests + rules: + - if: $FUNCTIONAL_TEST == "yes" && ($CI_PIPELINE_SOURCE == 'merge_request_event' && $CI_MERGE_REQUEST_TARGET_BRANCH_PROTECTED != "true") + allow_failure: true + - if: 
$FUNCTIONAL_TEST == "yes" + - when: never + +default: + id_tokens: + VAULT_JWT_TOKEN: + aud: https://stg.vault.nvidia.com + +include: + - project: dl/jet/gitlab-templates + ref: main + file: downstreams.yml + +functional:configure: + needs: + - test:build_image + - job: test:unit_tests_pyt(DEV)_mcore(latest) + optional: true + - job: test:unit_tests_pyt(LTS)_mcore(latest) + optional: true + extends: [.functional_tests_rules] + image: ${UTILITY_IMAGE}:${CI_PIPELINE_ID} + tags: + - arch/amd64 + - env/prod + - origin/jet-fleet + - owner/jet-core + - purpose/utility + - team/megatron + before_script: + - git rm -r tests/test_utils/local_recipes || true + - git submodule add --force https://gitlab-ci-token:${CI_JOB_TOKEN}@${GITLAB_ENDPOINT}/ADLR/megatron-lm-convergence-tests.git tests/test_utils/local_recipes + - ls tests/test_utils/local_recipes + script: + - set -x + - | + A100_CLUSTER=$([[ "$FUNCTIONAL_TEST_CLUSTER_A100" != "" ]] && echo $FUNCTIONAL_TEST_CLUSTER_A100 || echo $DEFAULT_A100_CLUSTER) + H100_CLUSTER=$([[ "$FUNCTIONAL_TEST_CLUSTER_H100" != "" ]] && echo $FUNCTIONAL_TEST_CLUSTER_H100 || echo $DEFAULT_H100_CLUSTER) + - | + RECORD_CHECKPOINTS=$([[ "$CI_MERGE_REQUEST_LABELS" == *"Record checkpoints"* ]] && echo "true" || echo "false") + - | + if [[ "$FUNCTIONAL_TEST_SCOPE" == "release" || "$FUNCTIONAL_TEST_SCOPE" == "pre-release" ]]; then + FUNCTIONAL_TEST_NAME=$(eval echo $FUNCTIONAL_TEST_NAME) + RELEASE_ARGS=( + "--run-name" + $FUNCTIONAL_TEST_NAME + "--wandb-experiment" + $(echo $FUNCTIONAL_TEST_NAME | tr '/' '-') + ) + else + RELEASE_ARGS=() + fi + - | + ARGS=( + "--scope $FUNCTIONAL_TEST_SCOPE" + "--n-repeat $FUNCTIONAL_TEST_REPEAT" + "--time-limit $FUNCTIONAL_TEST_TIME_LIMIT" + "--test-cases $FUNCTIONAL_TEST_CASES" + "--a100-cluster $A100_CLUSTER" + "--h100-cluster $H100_CLUSTER" + "--container-image ${UTILITY_IMAGE}" + "--container-tag ${CI_PIPELINE_ID}" + "--dependent-job functional:configure" + "--record-checkpoints ${RECORD_CHECKPOINTS}" + ) + - | + export PYTHONPATH=$(pwd) + python tests/test_utils/python_scripts/generate_jet_trigger_job.py \ + ${ARGS[@]} \ + --environment dev \ + --output-path "functional-test-job-dev.yaml" \ + ${RELEASE_ARGS[@]} + - | + export PYTHONPATH=$(pwd) + python tests/test_utils/python_scripts/generate_jet_trigger_job.py \ + ${ARGS[@]} \ + --environment lts \ + --output-path "functional-test-job-lts.yaml" \ + ${RELEASE_ARGS[@]} + artifacts: + paths: + - functional-test-job-lts.yaml + - functional-test-job-dev.yaml + - tests/test_utils/local_recipes + +.run: + stage: functional_tests + needs: [functional:configure] + extends: [.functional_tests_rules] + trigger: + include: + - artifact: functional-test-job-$ENVIRONMENT.yaml + job: functional:configure + strategy: depend + variables: + RO_API_TOKEN: $PAT + CONTAINER_TAG: $CI_PIPELINE_ID + CI_MCORE_LTS_IMAGE: $CI_MCORE_LTS_IMAGE + GITLAB_ENDPOINT: $GITLAB_ENDPOINT + PARENT_PIPELINE_ID: $CI_PIPELINE_ID + inherit: + variables: true + +functional:run_lts: + extends: [.run] + variables: + ENVIRONMENT: lts + +functional:run_dev: + extends: [.run] + variables: + ENVIRONMENT: dev + +functional:run_nemo: + extends: [.functional_tests_rules] + trigger: + project: 'dl/joc/nemo-ci' + branch: main-mirror + strategy: depend + inherit: + variables: true + variables: + MCORE_COMMIT: $CI_COMMIT_SHA + TEST_LLM_MODULE: 'True' + TEST_ALIGNER_MODULE: 'False' + TEST_DATA_CURATOR_MODULE: 'False' + TESTS_TO_RUN_ON_THIS_COMMIT: nightly + rules: + - if: $FUNCTIONAL_TEST == "yes" + when: manual + allow_failure: true + - 
when: never + +functional:notify: + extends: [.functional_tests_rules] + image: badouralix/curl-jq + needs: + - functional:run_lts + - functional:run_dev + tags: + - mcore-docker-node-small + variables: + WEBHOOK_URL: ${MCORE_NOTIFICATION_HOOK} + RO_API_TOKEN: ${PROJECT_ACCESS_TOKEN_MCORE} + CONTEXT: $FUNCTIONAL_TEST_SCOPE + script: + - apk add bash + - apk add --update coreutils + - env + - export WEBHOOK_URL=${MCORE_NOTIFICATION_HOOK} + - export RO_API_TOKEN=${PROJECT_ACCESS_TOKEN_MCORE} + - export GITLAB_ENDPOINT + - export CONTEXT=$FUNCTIONAL_TEST_SCOPE + - export DATE=$(date +"%Y-%m-%d") + - bash tests/test_utils/shell_scripts/notify.sh ${CI_PIPELINE_ID} "functional:run_" + artifacts: + when: always + paths: + - scripts + rules: + - if: $CI_PIPELINE_SOURCE == "schedule" && $FUNCTIONAL_TEST == "yes" + when: always + - when: never + +functional:download_golden_values: + extends: [.functional_tests_rules] + image: ${UTILITY_IMAGE}:${CI_PIPELINE_ID} + tags: + - mcore-docker-node-small + script: + - env + - export RO_API_TOKEN=${PROJECT_ACCESS_TOKEN_MCORE} + - export GITLAB_ENDPOINT + - python tests/test_utils/python_scripts/download_golden_values.py --pipeline-id ${CI_PIPELINE_ID} + artifacts: + paths: + - tests/ + rules: + - if: $FUNCTIONAL_TEST == "yes" + when: manual + allow_failure: true + - when: never diff --git a/.gitlab/stages/03.publish.yml b/.gitlab/stages/03.publish.yml new file mode 100644 index 0000000000..48ea9bfbfe --- /dev/null +++ b/.gitlab/stages/03.publish.yml @@ -0,0 +1,126 @@ +.publish_common_freeze: + stage: publish + rules: + - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH && $PUBLISH == "yes" && $PUBLISH_SCOPE == "code-freeze" + when: manual + - when: never + +.publish_common_release: + stage: publish + rules: + - if: $CI_COMMIT_BRANCH =~ /^core_r/ && $PUBLISH == "yes" && $PUBLISH_SCOPE == "release" + when: manual + - if: $PUBLISH == "yes" && $PUBLISH_SCOPE == "release" + when: manual + variables: + PUBLISH_DRYRUN: 'yes' + - when: never + +publish:release_branch: + extends: [.publish_common_freeze] + image: ${CI_MCORE_LTS_IMAGE}:${CI_PIPELINE_ID} + needs: [test:build_image] + tags: + - arch/amd64 + - env/prod + - origin/jet-fleet + - owner/jet-core + - purpose/utility + - team/megatron + variables: + GIT_STRATEGY: 'none' + script: + - git fetch origin $CI_DEFAULT_BRANCH + - git config --global user.email "mcore-bot@nvidia.com" + - git config --global user.name "Mcore Bot" + - git remote set-url origin "https://gitlab-ci-token:${PAT}@${GITLAB_ENDPOINT}/$CI_PROJECT_NAMESPACE/megatron-lm.git" + - sed -i "/^PRE_RELEASE/c\PRE_RELEASE = ''" megatron/core/package_info.py + - VERSION=$(python -c "from megatron import core; print(core.__version__)") + - RELEASE_BRANCH=core_r$VERSION + - git switch --force-create $RELEASE_BRANCH origin/$CI_DEFAULT_BRANCH + - | + MESSAGE='{ + "blocks": [ + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": "Releasebot 🤖: Megatron Core has been frozen 🎉 to branch `'"$RELEASE_BRANCH"'`" + } + } + ] + }' + - > + curl -X POST -H "Content-type: application/json" --data "$MESSAGE" ${MCORE_NOTIFICATION_HOOK_MAIN} + + + - git switch --force-create bot/chore/bump-version + - git add megatron/core/package_info.py + - > + git commit -m "chore: adjust version version" + + + - git push -u origin bot/chore/bump-version + - > + curl \ + --header "PRIVATE-TOKEN: $PAT" \ + --url https://${GITLAB_ENDPOINT}/api/v4/projects/${CI_PROJECT_ID}/merge_requests \ + -d "source_branch=bot/chore/bump-version" \ + -d "target_branch=$RELEASE_BRANCH" \ + -d 
"title=chore: Fix version of \`$RELEASE_BRANCH\`" \ + -d "description=[🤖]: Hi @okoenig 👋,

we've adjusted the version number of \`$RELEASE_BRANCH\` for you! 🚀

Please review and approve this cherry pick by your convenience\!" + +publish:pypi_build_wheel: + extends: [test:pypi_build_wheel, .publish_common_release] + dependencies: [] + variables: + PUBLISH_DRYRUN: 'no' + +publish:pypi_test_wheel: + extends: [test:pypi_test_wheel, .publish_common_release] + needs: [publish:pypi_build_wheel] + variables: + PUBLISH_DRYRUN: 'no' + +publish:pypi_push_wheel: + extends: [test:pypi_push_wheel, .publish_common_release] + needs: [publish:pypi_test_wheel] + dependencies: [publish:pypi_test_wheel] + variables: + PUBLISH_DRYRUN: 'no' + +publish:gh_release: + extends: [test:gh_release, .publish_common_release] + dependencies: [publish:pypi_test_wheel] + needs: [publish:pypi_test_wheel] + variables: + PUBLISH_DRYRUN: 'no' + +publish:notify_release: + needs: [publish:pypi_push_wheel, publish:gh_release] + extends: [test:notify_release, .publish_common_release] + variables: + PUBLISH_DRYRUN: 'no' + +publish:docs: + extends: [.publish_common_release] + image: ${UTILITY_IMAGE}:${CI_PIPELINE_ID} + tags: + - arch/amd64 + - env/prod + - origin/jet-fleet + - owner/jet-core + - purpose/utility + - team/megatron + script: + - cd .. + - rm -rf documentation && git clone https://gitlab-ci-token:${PROJECT_ACCESS_TOKEN_MCORE}@${GITLAB_ENDPOINT}/nemo-megatron-core-tme/documentation.git + - cd documentation/megatron-lm + git fetch origin '+refs/merge-requests/*:refs/remotes/merge-requests/*' + - git fetch origin $CI_COMMIT_SHA + - git checkout $CI_COMMIT_SHA + - cd .. + - git add megatron-lm + - > + git commit -m 'feat: Bump mcore' + - git push diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000000..865f483849 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,13 @@ +[MAIN] +ignore-paths=tests +max-line-length=100 + +[MESSAGES CONTROL] +disable=all + +enable=C0115,C0116,W0611,C0301,E0606 +# C0115: missing-class-docstring +# C0116: missing-function-docstring +# W0611: unused-import +# C0301: line-too-long +# E0606: possibly-used-before-assignment \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000000..01e3748724 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,132 @@ +# Changelog + +## NVIDIA Megatron Core 0.10.0 + +- Adding MLA to MCore +- Enable FP8 for GroupedMLP +- MoE Parallel Folding +- Enhance MoE Architecture: Support MoE Layer Frequency Patterns and Configurable MoE FFN Hidden Size +- Multimodal: NVLM training and evaluation support in MCore +- Mamba Hybrid + - Increase performance and reduce memory footprint of Triton language/compiler distributed caching + - Add more unit testing and fix bugs + +## NVIDIA Megatron Core 0.9.0 + +- Uneven pipeline parallelism + - Enable pipeline parallelism where first and last ranks have fewer transformer layers than the intermediate ranks +- Per layer CUDAGraph support for GPT training with Transformer Engine modules +- Enable different TP sizes for the vision encoder +- Enable pipeline parallelism for T5 & Llava models +- Support multi-tile multi-image input in Llava models +- MoE + - FP8 support + - Runtime upcycling support + - Dispatcher implementation optimizations + - Shared expert support with overlapping optimizations + - Qwen Model support +- Known Issues + - When using sequence parallel, during the transformer block forward pass, dropout is not using the appropriate rng context. 
+ +## NVIDIA Megatron Core 0.8.0 + +- Multimodal + - Added initial support for training vision language models using the LLaVA architecture + - Added initial support for inference with multimodal inputs + - End-to-end multimodal example from data collection to training to evaluation is provided in examples/multimodal +- MoE + - Context Parallel support. + - Distributed checkpoint support for grouped GEMM. +- Mamba + +## NVIDIA Megatron Core 0.7.0 + +- MoE + - Token drop support + - Several efficiency optimizations + - Improved model parallelism + - Memory optimizations +- Distributed checkpointing + - Enabled for Retro + - Asynchronous checkpoint saving +- Several minor bug fixes, speed improvements, and memory optimizations + +## NVIDIA Megatron Core 0.6.0 + +- MoE (Mixture of Experts) + - Performance optimization + - Communication optimization for multi GPU and Single GPU + - 23% improvement (323 TFLOPS/GPU) over MCore 0.5.0 on Mixtral with Hopper BF16 + - GroupedMLP enhancement for Hopper + - DP Overlapping. Support overlapping computation with gradient reduction and parameter gathering. + - All-to-All based Token Dispatcher + - Layer-wise logging for load balancing loss. + - Improved expert parallel support including distributed optimizer. +- Distributed optimizer +- RETRO + - Data processing +- BERT + - Distributed checkpointing +- Dist checkpointing + - PyTorch native distributed backend + - Improved saving/loading speed +- TensorRT-LLM Export + - Integration with TensorRT Model Optimizer Post-training quantization (PTQ) + - Text generation driver to perform PTQ in Megatron-LM + - Llama2 and Nemotron3-8b examples to use TensorRT-LLM unified build API to build engine after training. +- Several minor enhancements, bug fixes, and documentation updates + +## NVIDIA Megatron Core 0.5.0 + +### Key Features and Enhancements + +Megatron core documentation is now [live!](https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start) + +### Model Features + +- MoE (Mixture of Experts) + - Support for Z-loss, Load balancing and Sinkhorn + - Layer and communications refactor + - Richer parallelism mappings and EP can be combined with other model parallel techniques for larger MoE variants, e.g. 
EP + TP + DP + SP + PP
+  - Token dropless architecture with Top-K routing
+  - Performance optimization with GroupedGEMM when number of local experts is > 1
+  - Distributed checkpointing
+- Interleaved rotary embedding
+
+### Datasets
+
+- Masked WordPiece datasets for BERT and T5
+- Raw and mock datasets
+
+### Parallelism
+
+### Performance
+
+- Activation offloading to CPU
+- Rope and Swiglu fusion
+- Sliding window attention (via Transformer Engine)
+
+### General Improvements
+
+- Timers
+
+## NVIDIA Megatron Core 0.4.0
+
+### Key Features and Enhancements
+
+#### Models
+
+- BERT
+- RETRO
+- T5
+
+#### Parallelism
+
+- Mixture of Experts support for GPT
+- Model parallel efficient Distributed Data Parallel (DDP)
+- Context Parallel (2D Tensor Parallel) support
+
+#### Datasets
+
+- GPT Dataset
+- Blended Dataset
diff --git a/CODEOWNERS b/CODEOWNERS
index cf30f9c148..e89c62b06e 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1,6 +1,49 @@
-[MCORE][3]
-megatron/core/ @shanmugamr @maanug @jcasper @eharper
+[Core-ADLR] @mcore-reviewers/core-adlr
+megatron/core/

-[TESTS]
-tests/ @shanmugamr @maanug
+[Core-NeMo] @mcore-reviewers/core-nemo
+megatron/core/
+^[Core-MLPerf] @mcore-reviewers/mlperf
+megatron/core/
+
+[MoE-ADLR] @mcore-reviewers/moe-adlr
+megatron/core/transformer/moe/
+
+[MoE-Moe] @mcore-reviewers/moe-moe
+megatron/core/transformer/moe/
+
+[Datasets] @mcore-reviewers/datasets
+megatron/core/datasets/
+
+[BERT] @mcore-reviewers/bert
+megatron/core/models/bert/
+
+[GPT] @mcore-reviewers/gpt
+megatron/core/models/gpt/
+
+[Retro] @mcore-reviewers/retro
+megatron/core/models/retro/
+
+[Distributed Checkpointing] @mcore-reviewers/dist-checkpointing
+megatron/core/dist_checkpointing/
+
+[Distributed Optimizer] @mcore-reviewers/dist-optimizer
+megatron/core/optimizer/distrib_optimizer/
+
+[Inference] @mcore-reviewers/inference
+megatron/core/inference/
+
+^[Quantization and Inference (QAT)] @mcore-reviewers/quantization-and-inference
+megatron/core/inference/
+
+; [Context Parallelism] @mcore-reviewers/context-parallelism
+;
+
+[CI] @mcore-reviewers/ci
+.gitlab/
+.github/
+.gitlab-ci.yml
+Dockerfile.ci.lts
+Dockerfile.ci.dev
+tests/
diff --git a/Dockerfile.ci.dev b/Dockerfile.ci.dev
new file mode 100644
index 0000000000..4516e33154
--- /dev/null
+++ b/Dockerfile.ci.dev
@@ -0,0 +1,84 @@
+# syntax=docker/dockerfile:1.3-labs
+
+ARG FROM_IMAGE_NAME
+FROM $FROM_IMAGE_NAME as build_causal_conv1d
+WORKDIR /opt
+RUN CAUSAL_CONV1D_FORCE_BUILD=TRUE pip3 wheel -v git+https://github.com/Dao-AILab/causal-conv1d.git@v1.2.2.post1
+
+FROM $FROM_IMAGE_NAME as build_grouped_gemm
+WORKDIR /opt
+RUN pip3 wheel -v git+https://github.com/fanshiqing/grouped_gemm@v1.1.2
+
+FROM $FROM_IMAGE_NAME as build_mamba_ssm
+WORKDIR /opt
+RUN git clone https://github.com/state-spaces/mamba.git && \
+    cd mamba && \
+    git checkout v2.2.0 && \
+    sed -i "/triton/d" setup.py && \
+    MAMBA_FORCE_BUILD=TRUE pip3 wheel -v .
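+
+# NOTE: the wheel-build stages above can be smoke-tested in isolation with a
+# plain BuildKit build. The command below is an illustrative sketch only; it
+# assumes the same dev base image that the CI build matrix passes in as
+# FROM_IMAGE_NAME (nvcr.io/nvidia/pytorch:24.10-py3) and a local Docker daemon:
+#
+#   DOCKER_BUILDKIT=1 docker build \
+#     --build-arg FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:24.10-py3 \
+#     --target build_mamba_ssm \
+#     -f Dockerfile.ci.dev .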
+ +FROM $FROM_IMAGE_NAME as main +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && \ + apt-get install -y --no-install-recommends gettext python3-venv && \ + apt-get clean && \ + python -m venv /opt/jet && \ + wget https://github.com/mikefarah/yq/releases/download/v4.44.1/yq_linux_amd64 -O /usr/local/bin/yq && \ + chmod a+x /usr/local/bin/yq + +COPY --from=build_causal_conv1d /opt/causal_conv1d-*.whl ./ +COPY --from=build_grouped_gemm /opt/grouped_gemm-*.whl ./ +COPY --from=build_mamba_ssm /opt/mamba/mamba_ssm-*.whl ./ + +RUN \ + --mount=type=bind,source=requirements,target=requirements \ + --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ + --mount=type=bind,source=setup.py,target=setup.py \ + --mount=type=bind,source=megatron/core/package_info.py,target=megatron/core/package_info.py \ + --mount=type=bind,source=megatron/core/README.md,target=megatron/core/README.md \ + --mount=type=bind,source=megatron/core/requirements.txt,target=megatron/core/requirements.txt \ + --mount=type=bind,source=megatron/core/__init__.py,target=megatron/core/__init__.py <<"EOF" bash -ex + +pip install causal_conv1d-*.whl mamba_ssm-*.whl grouped_gemm-*.whl +PY_ENV=pytorch_24.10 pip install . +EOF + +# Since megatron does not have any dependencies (and isn't a dependency to any other package), we can install it separately to make everything a bit quicker +ARG MCORE_REPO +ARG MCORE_REF +ARG MCORE_BACKWARDS_REF +RUN <<"EOF" bash -exu +# Checkout latest +cd /opt +rm -rf /opt/megatron-lm; mkdir megatron-lm; cd megatron-lm +git init +git remote add origin ${MCORE_REPO} +git fetch origin '+refs/merge-requests/*:refs/remotes/merge-requests/*' +git fetch origin $MCORE_REF +git checkout $MCORE_REF + +# Checkout backwards-ref +cd /opt +rm -rf /opt/megatron-lm-legacy; mkdir megatron-lm-legacy; cd megatron-lm-legacy +git init +git remote add origin ${MCORE_REPO} +git fetch origin $MCORE_BACKWARDS_REF +git checkout $MCORE_BACKWARDS_REF +rm -rf megatron; cp -a /opt/megatron-lm/megatron ./ +EOF + +RUN PY_ENV=pytorch_24.10 pip install -e /opt/megatron-lm +ENV PYTHONPATH="/opt/megatron-lm:$PYTHONPATH" + +##### For NVIDIANS only ##### +FROM main as jet +ARG CACHEBUST=0 +RUN --mount=type=secret,id=JET_INDEX_URLS \ + --mount=type=secret,id=LOGGER_INDEX_URL \ + LOGGER_INDEX_URL=$(cat /run/secrets/LOGGER_INDEX_URL) && \ + JET_INDEX_URLS=$(cat /run/secrets/JET_INDEX_URLS) && \ + pip install "jet-client~=2.0" jet-api --upgrade $JET_INDEX_URLS && \ + pip install "one-logger" --upgrade $LOGGER_INDEX_URL +ENV PATH="$PATH:/opt/jet/bin" +### \ No newline at end of file diff --git a/Dockerfile.ci.lts b/Dockerfile.ci.lts new file mode 100644 index 0000000000..327934a457 --- /dev/null +++ b/Dockerfile.ci.lts @@ -0,0 +1,81 @@ +# syntax=docker/dockerfile:1.3-labs + +ARG FROM_IMAGE_NAME +FROM $FROM_IMAGE_NAME as build_causal_conv1d +WORKDIR /opt +RUN CAUSAL_CONV1D_FORCE_BUILD=TRUE pip3 wheel -v git+https://github.com/Dao-AILab/causal-conv1d.git@v1.2.2.post1 + +FROM $FROM_IMAGE_NAME as build_grouped_gemm +WORKDIR /opt +RUN pip3 wheel -v git+https://github.com/fanshiqing/grouped_gemm@v1.1.2 + +FROM $FROM_IMAGE_NAME as build_mamba_ssm +WORKDIR /opt +RUN MAMBA_FORCE_BUILD=TRUE pip3 wheel -v git+https://github.com/state-spaces/mamba.git@v2.0.3 + +ARG FROM_IMAGE_NAME +FROM $FROM_IMAGE_NAME as main +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && \ + apt-get install -y --no-install-recommends gettext python3-venv && \ + apt-get clean && \ + python -m venv /opt/jet && \ + wget 
https://github.com/mikefarah/yq/releases/download/v4.44.1/yq_linux_amd64 -O /usr/local/bin/yq && \ + chmod a+x /usr/local/bin/yq + +COPY --from=build_causal_conv1d /opt/causal_conv1d-*.whl ./ +COPY --from=build_grouped_gemm /opt/grouped_gemm-*.whl ./ +COPY --from=build_mamba_ssm /opt/mamba_ssm-*.whl ./ + +RUN \ + --mount=type=bind,source=requirements,target=requirements \ + --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ + --mount=type=bind,source=setup.py,target=setup.py \ + --mount=type=bind,source=megatron/core/package_info.py,target=megatron/core/package_info.py \ + --mount=type=bind,source=megatron/core/README.md,target=megatron/core/README.md \ + --mount=type=bind,source=megatron/core/requirements.txt,target=megatron/core/requirements.txt \ + --mount=type=bind,source=megatron/core/__init__.py,target=megatron/core/__init__.py <<"EOF" bash -ex + +pip install causal_conv1d-*.whl mamba_ssm-*.whl grouped_gemm-*.whl +PY_ENV=pytorch_24.01 pip install . +EOF + +# Since megatron does not have any dependencies (and isn't a dependency to any other package), we can install it separately to make everything a bit quicker +ARG MCORE_REPO +ARG MCORE_REF +ARG MCORE_BACKWARDS_REF +RUN <<"EOF" bash -exu +# Checkout latest +cd /opt +rm -rf /opt/megatron-lm; mkdir megatron-lm; cd megatron-lm +git init +git remote add origin ${MCORE_REPO} +git fetch origin '+refs/merge-requests/*:refs/remotes/merge-requests/*' +git fetch origin $MCORE_REF +git checkout $MCORE_REF + +# Checkout backwards-ref +cd /opt +rm -rf /opt/megatron-lm-legacy; mkdir megatron-lm-legacy; cd megatron-lm-legacy +git init +git remote add origin ${MCORE_REPO} +git fetch origin $MCORE_BACKWARDS_REF +git checkout $MCORE_BACKWARDS_REF +rm -rf megatron; cp -a /opt/megatron-lm/megatron ./ +EOF + +RUN PY_ENV=pytorch_24.01 pip install -e /opt/megatron-lm +ENV PYTHONPATH="/opt/megatron-lm:$PYTHONPATH" + +##### For NVIDIANS only ##### +FROM main as jet +ARG CACHEBUST=0 +RUN --mount=type=secret,id=JET_INDEX_URLS \ + --mount=type=secret,id=LOGGER_INDEX_URL \ + LOGGER_INDEX_URL=$(cat /run/secrets/LOGGER_INDEX_URL) && \ + JET_INDEX_URLS=$(cat /run/secrets/JET_INDEX_URLS) && \ + pip install "jet-client~=2.0" jet-api --upgrade $JET_INDEX_URLS && \ + pip install "one-logger" --upgrade $LOGGER_INDEX_URL +ENV PATH="$PATH:/opt/jet/bin" +### \ No newline at end of file diff --git a/Dockerfile.linting b/Dockerfile.linting new file mode 100644 index 0000000000..9034ee8412 --- /dev/null +++ b/Dockerfile.linting @@ -0,0 +1,33 @@ +# syntax=docker/dockerfile:experimental + +ARG FROM_IMAGE_NAME +FROM $FROM_IMAGE_NAME as main +ENV DEBIAN_FRONTEND=noninteractive + +RUN sed -i -e 's/^APT/# APT/' -e 's/^DPkg/# DPkg/' \ + /etc/apt/apt.conf.d/docker-clean + +RUN apt-get update && \ + apt-get install -y python3-venv && \ + apt-get clean && \ + python -m venv /opt/jet + +RUN pip3 install --no-cache-dir \ + black==24.4.2 \ + isort==5.13.2 \ + flake8==7.1.0 \ + pylint==3.2.6 \ + mypy + +COPY . 
/opt/megatron-lm + +WORKDIR /opt/megatron-lm + +##### For NVIDIANS only ##### +FROM main as jet +ARG CACHEBUST=0 +RUN --mount=type=secret,id=JET_INDEX_URLS \ + JET_INDEX_URLS=$(cat /run/secrets/JET_INDEX_URLS) && \ + pip install "jet-client~=2.0" jet-api --upgrade $JET_INDEX_URLS +ENV PATH="$PATH:/opt/jet/bin" +### \ No newline at end of file diff --git a/LICENSE b/LICENSE index 72c23ae15e..7ad15e594b 100644 --- a/LICENSE +++ b/LICENSE @@ -29,13 +29,15 @@ The following applies to all files unless otherwise noted: -- This repository also contains code from Hugging Face Inc., Google Research, -Facebook (from their Fairseq and Dino projects), Microsoft(from their -Swin-Transformer project)and Philip Popien. Files from these -organizations have notices at the top of each file. Below are -licenses used in those files, as indicated. +Facebook (from their Fairseq, Dino, and ParlAI projects), Microsoft (from their +Swin-Transformer project), Philip Popien, the Mamba project (Tri Dao and +Albert Gu), and the Triton language and compiler project (Philippe Tillet and +OpenAI). Files from these organizations have notices at the top of each file. +Below are licenses used in those files, as indicated. -------------- LICENSE FOR Facebook, huggingface and Google Research code -------------- +-------------------------------------------------------------------------------------- +-- LICENSE FOR Facebook, huggingface, Google Research, LLaVA, Mamba, and vLLM code -- Apache License @@ -240,12 +242,16 @@ licenses used in those files, as indicated. See the License for the specific language governing permissions and limitations under the License. -------------- LICENSE FOR Facebook Fairseq code -------------- +-------------------------------------------------------------------------------- +LICENSE FOR +Facebook, Inc. and its affiliates, +Meta Platforms, Inc. and its affiliates, +Microsoft Corporation, +OpenGVLab/InternVL, and +Triton language and compiler. MIT License -Copyright (c) Facebook, Inc. and its affiliates. - Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights @@ -264,28 +270,3 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------- LICENSE FOR Mircrosoft Swin transformer code -------------- - -MIT License - -Copyright (c) Microsoft Corporation. - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE - - diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000..dbed9c4061 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,3 @@ +include megatron/core/requirements.txt +include megatron/core/README.md +recursive-include requirements * diff --git a/README.md b/README.md index dfe29ffb0b..f98bd5281e 100644 --- a/README.md +++ b/README.md @@ -1,64 +1,94 @@ -Megatron ([1](https://arxiv.org/pdf/1909.08053.pdf), [2](https://arxiv.org/pdf/2104.04473.pdf), and [3](https://arxiv.org/pdf/2205.05198)) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel ([tensor](https://arxiv.org/pdf/1909.08053.pdf), [sequence](https://arxiv.org/pdf/2205.05198), and [pipeline](https://arxiv.org/pdf/2104.04473.pdf)), and multi-node pre-training of transformer based models such as [GPT](https://arxiv.org/abs/2005.14165), [BERT](https://arxiv.org/pdf/1810.04805.pdf), and [T5](https://arxiv.org/abs/1910.10683) using mixed precision. +
+ +Megatron-LM & Megatron-Core +=========================== +

GPU optimized techniques for training transformer models at-scale

+ +[![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html) +[![version](https://img.shields.io/badge/release-0.5.0-green)](./setup.py) +[![license](https://img.shields.io/badge/license-OpenBSD-blue)](./LICENSE) + +
+ +# Latest News + +- **[2024/7]** Megatron-Core v0.7 improves scalability and training resiliency and adds support for multimodal training ([blog](https://developer.nvidia.com/blog/train-generative-ai-models-more-efficiently-with-new-nvidia-megatron-core-functionalities/)). +- **[2024/6]** Megatron-Core added supports for Mamba-based models. Check out our paper [An Empirical Study of Mamba-based Language Models](https://arxiv.org/pdf/2406.07887) and [code example](https://github.com/NVIDIA/Megatron-LM/tree/ssm/examples/mamba). +- **[2024/1 Announcement]** NVIDIA has released the core capabilities in **Megatron-LM** into [**Megatron-Core**](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core) in this repository. Megatron-Core expands upon Megatron-LM's GPU-optimized techniques with more cutting-edge innovations on system-level optimizations, featuring composable and modular APIs. Explore the [Megatron-Core intro](#megatron-core) for more details. + +# Table of Contents + +- [Megatron-LM \& Megatron-Core](#megatron-lm--megatron-core) +- [Latest News](#latest-news) +- [Table of Contents](#table-of-contents) +- [Megatron Overview](#megatron-overview) + - [Megatron-LM](#megatron-lm) + - [Megatron-Core](#megatron-core) +- [Training Speed and Scalability](#training-speed-and-scalability) +- [Setup](#setup) + - [Downloading Checkpoints](#downloading-checkpoints) +- [Usage](#usage) +- [Training](#training) + - [Data Preprocessing](#data-preprocessing) + - [BERT Pretraining](#bert-pretraining) + - [GPT Pretraining](#gpt-pretraining) + - [T5 Pretraining](#t5-pretraining) + - [Distributed Pretraining](#distributed-pretraining) + - [Activation Checkpointing and Recomputation](#activation-checkpointing-and-recomputation) + - [Distributed Optimizer](#distributed-optimizer) + - [FlashAttention](#flashattention) + - [GPT-3 Example](#gpt-3-example) + - [Retro and InstructRetro](#retro-and-instructretro) + - [Mamba-based Language Models](#mamba-based-language-models) + - [Mixture of Experts](#mixture-of-experts) +- [Evaluation and Tasks](#evaluation-and-tasks) + - [GPT Text Generation](#gpt-text-generation) + - [Detoxify GPT via Self-generation](#detoxify-gpt-via-self-generation) + - [GPT Evaluation](#gpt-evaluation) + - [WikiText Perplexity Evaluation](#wikitext-perplexity-evaluation) + - [LAMBADA Cloze Accuracy](#lambada-cloze-accuracy) + - [BERT Task Evaluation](#bert-task-evaluation) + - [RACE Evaluation](#race-evaluation) + - [MNLI Evaluation](#mnli-evaluation) + - [Llama-2 Inference and Finetuning](#llama-2-inference-and-finetuning) +- [Model Optimization and Deployment](#model-optimization-and-deployment) + - [Quantization and TensorRT-LLM Deployment](#quantization-and-tensorrt-llm-deployment) +- [Datasets](#datasets) + - [Collecting Wikipedia Training Data](#collecting-wikipedia-training-data) + - [Collecting GPT Webtext Data](#collecting-gpt-webtext-data) +- [Reproducibility](#reproducibility) +- [Checkpoint conversion](#checkpoint-conversion) + - [Model class conversion](#model-class-conversion) + - [Checkpoint format conversion](#checkpoint-format-conversion) +- [Projects Using Megatron](#projects-using-megatron) + +# Megatron Overview +This repository comprises two essential components: **Megatron-LM** and **Megatron-Core**. Megatron-LM serves as a research-oriented framework leveraging Megatron-Core for large language model (LLM) training. 
Megatron-Core, on the other hand, is a library of GPU optimized training techniques that comes with formal product support including versioned APIs and regular releases. You can use Megatron-Core alongside Megatron-LM or [Nvidia NeMo Framework](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/nlp/nemo_megatron/mcore_customization.html) for an end-to-end and cloud-native solution. Alternatively, you can integrate Megatron-Core's building blocks into your preferred training framework. + +## Megatron-LM +First introduced in 2019, Megatron ([1](https://arxiv.org/pdf/1909.08053.pdf), [2](https://arxiv.org/pdf/2104.04473.pdf), and [3](https://arxiv.org/pdf/2205.05198)) sparked a wave of innovation in the AI community, enabling researchers and developers to utilize the underpinnings of this library to further LLM advancements. Today, many of the most popular LLM developer frameworks have been inspired by and built directly leveraging the open-source Megatron-LM library, spurring a wave of foundation models and AI startups. Some of the most popular LLM frameworks built on top of Megatron-LM include [Colossal-AI](https://github.com/hpcaitech/ColossalAI), [HuggingFace Accelerate](https://github.com/huggingface/accelerate), and [NVIDIA NeMo Framework](https://www.nvidia.com/en-us/ai-data-science/generative-ai/nemo-framework/). A list of projects that have directly used Megatron can be found [here](#projects-using-megatron). + +## Megatron-Core +Megatron-Core is an open-source PyTorch-based library that contains GPU-optimized techniques and cutting-edge system-level optimizations. It abstracts them into composable and modular APIs, allowing full flexibility for developers and model researchers to train custom transformers at-scale on NVIDIA accelerated computing infrastructure. This library is compatible with all NVIDIA Tensor Core GPUs, including FP8 acceleration support for [NVIDIA Hopper architectures](https://www.nvidia.com/en-us/data-center/technologies/hopper-architecture/). + +Megatron-Core offers core building blocks such as attention mechanisms, transformer blocks and layers, normalization layers, and embedding techniques. Additional functionality like activation recomputation, distributed checkpointing is also natively built-in to the library. The building blocks and functionality are all GPU optimized, and can be built with advanced parallelization strategies for optimal training speed and stability on NVIDIA Accelerated Computing Infrastructure. Another key component of the Megatron-Core library includes advanced model parallelism techniques (tensor, sequence, pipeline, context, and MoE expert parallelism). + +Megatron-Core can be used with [NVIDIA NeMo](https://www.nvidia.com/en-us/ai-data-science/products/nemo/), an enterprise-grade AI platform. Alternatively, you can explore Megatron-Core with the native PyTorch training loop [here](https://github.com/NVIDIA/Megatron-LM/tree/main/examples). Visit [Megatron-Core documentation](https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html) to learn more. + + +# Training Speed and Scalability +Our codebase is capable of efficiently training large language models (i.e., models with hundreds of billions of parameters) with both model and data parallelism. To demonstrate how our software scales with multiple GPUs and model sizes, we consider GPT models ranging from 2 billion parameters to 462 billion parameters. All models use a vocabulary size of 131,072 and a sequence length of 4096. 
We vary hidden size, number of attention heads, and number of layers to arrive at a specific model size. As the model size increases, we also modestly increase batch size. Our experiments use up to 6144 [H100](https://www.nvidia.com/en-us/data-center/h100/) GPUs. We perform fine-grained overlapping of data-parallel (`--overlap-grad-reduce --overlap-param-gather`), tensor-parallel (`--tp-comm-overlap`) and pipeline-parallel communication (enabled by default) with computation to improve scalability. The reported throughputs are measured for end-to-end training and include all operations including data loading, optimizer steps, communication, and even logging. Note that we did not train these models to convergence. + +![Model table](images/model_table.png) + +Our weak scaled results show superlinear scaling (MFU increases from 41% for the smallest model considered to 47-48% for the largest models); this is because larger GEMMs have higher arithmetic intensity and are consequently more efficient to execute. + +![Weak scaling](images/weak_scaling.png) + +We also strong scaled the standard GPT-3 model (our version has slightly more than 175 billion parameters due to larger vocabulary size) from 96 H100 GPUs to 4608 GPUs, using the same batch size of 1152 sequences throughout. Communication becomes more exposed at larger scale, leading to a reduction in MFU from 47% to 42%. + +![Strong scaling](images/strong_scaling.png) -Below are some of the projects where we have directly used Megatron: -* [BERT and GPT Studies Using Megatron](https://arxiv.org/pdf/1909.08053.pdf) -* [BioMegatron: Larger Biomedical Domain Language Model](https://www.aclweb.org/anthology/2020.emnlp-main.379.pdf) -* [End-to-End Training of Neural Retrievers for Open-Domain Question Answering](https://arxiv.org/abs/2101.00408) -* [Large Scale Multi-Actor Generative Dialog Modeling](https://www.aclweb.org/anthology/2020.acl-main.8.pdf) -* [Local Knowledge Powered Conversational Agents](https://arxiv.org/abs/2010.10150) -* [MEGATRON-CNTRL: Controllable Story Generation with External Knowledge Using Large-Scale Language Models](https://www.aclweb.org/anthology/2020.emnlp-main.226.pdf) -* [RACE Reading Comprehension Dataset Leaderboard](http://www.qizhexie.com/data/RACE_leaderboard.html) -* [Training Question Answering Models From Synthetic Data](https://www.aclweb.org/anthology/2020.emnlp-main.468.pdf) -* [Few-shot Instruction Prompts for Pretrained Language Models to Detect Social Biases](https://arxiv.org/abs/2112.07868) -* [Exploring the Limits of Domain-Adaptive Training for Detoxifying Large-Scale Language Models](https://arxiv.org/abs/2202.04173) -* [Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model](https://arxiv.org/abs/2201.11990) -* [Multi-Stage Prompting for Knowledgeable Dialogue Generation](https://arxiv.org/abs/2203.08745) -* [Evaluating Parameter Efficient Learning for Generation](https://aclanthology.org/2022.emnlp-main.319.pdf) - -Megatron is also used in [NeMo Megatron](https://developer.nvidia.com/nvidia-nemo#nemo-megatron), a framework to help enterprises overcome the challenges of building and training sophisticated natural language processing models with billions and trillions of parameters. - -Our codebase is capable of efficiently training very large (hundreds of billions of parameters) language models with both model and data parallelism. 
To demonstrate how the code scales with multiple GPUs and model sizes, we consider GPT models from 1 billion all the way to 1 trillion parameters. All models use a vocabulary size of 51,200 and a sequence length of 2048. We vary hidden size, number of attention heads, and number of layers to arrive at a specific model size. As the model size increases, we also modestly increase the batch size. We leverage [NVIDIA's Selene supercomputer](https://www.top500.org/system/179842/) to perform scaling studies and use up to 3072 [A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for the largest model. Each cluster node has 8 NVIDIA 80GB A100 GPUs. The graph below shows that we scale nearly linear up to 1 trillion parameter models running on 3072 GPUs. Note that these results are from benchmark runs and these models were not trained to convergence; however, the FLOPs are measured for end-to-end training, i.e., includes all operations including data loading, optimization, and even logging. - -![Scaling Graph](images/Achieved_petaFLOPs.png) - -The following table shows both model (MFU) and hardware (HFU) FLOPs utilization for select configurations up to 1T parameters (see [our paper](https://arxiv.org/pdf/2205.05198) for a description of how these are calculated). As the model size increases, we achieve better GPU utilization and for the one trillion parameter model, we reach a MFU and HFU of 56.3% and 57.0%, respectively. Note that these numbers are also measured on benchmark runs and in this case are measured using a data parallel size of one. Data parallelism introduces some overhead due to the gradient all-reduce required between the data parallel groups. However, for large transformer models, this overhead is not large and can almost entirely eliminated by overlapping the gradient all-reduce with backpropagation. - -| Model Size | Model FLOPs Utilization | Hardware FLOPs Utilization | -| :---: | :---: | :---: | -| 22B | 41.5% | 43.7% | -| 175B | 51.4% | 52.8% | -| 530B | 56.0% | 57.0% | -| 1T | 56.3% | 57.0% | - -# Contents - * [Contents](#contents) - * [Setup](#setup) - * [Downloading Checkpoints](#downloading-checkpoints) - * [Usage](#usage) - * [Training](#training) - * [Data Preprocessing](#data-preprocessing) - * [BERT Pretraining](#bert-pretraining) - * [GPT Pretraining](#gpt-pretraining) - * [T5 Pretraining](#t5-pretraining) - * [Distributed Pretraining](#distributed-pretraining) - * [Activation Checkpointing and Recomputation](#activation-checkpointing-and-recomputation) - * [Distributed Optimizer](#distributed-optimizer) - * [FlashAttention](#flashattention) - * [GPT-3 Example](#gpt-3-example) - * [Retro](#retro) - * [Evaluation and Tasks](#evaluation-and-tasks) - * [GPT Text Generation](#gpt-text-generation) - * [GPT Evaluation](#gpt-evaluation) - * [WikiText Perplexity Evaluation](#wikitext-perplexity-evaluation) - * [LAMBADA Cloze Accuracy](#lambada-cloze-accuracy) - * [BERT Task Evaluation](#bert-task-evaluation) - * [RACE Evaluation](#race-evaluation) - * [MNLI Evaluation](#mnli-evaluation) - * [Llama-2 Inference and Finetuning](#llama-2-inference-and-finetuning) - * [Datasets](#datasets) - * [Collecting Wikipedia Training Data](#collecting-wikipedia-training-data) - * [Collecting GPT Webtext Data](#collecting-gpt-webtext-data) - * [Reproducibility](#reproducibility) # Setup We strongly recommend using the latest release of [NGC's PyTorch container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) with DGX nodes. 
If you can't use this for some reason, use the latest pytorch, cuda, nccl, and NVIDIA [APEX](https://github.com/NVIDIA/apex#quick-start) releases. Data preprocessing requires [NLTK](https://www.nltk.org/install.html), though this is not required for training, evaluation, or downstream tasks. @@ -70,7 +100,7 @@ docker run --gpus all -it --rm -v /path/to/megatron:/workspace/megatron -v /path ``` ## Downloading Checkpoints -We have provided pretrained [BERT-345M](https://ngc.nvidia.com/catalog/models/nvidia:megatron_bert_345m) and [GPT-345M](https://ngc.nvidia.com/catalog/models/nvidia:megatron_lm_345m) checkpoints for use to evaluate or finetuning downstream tasks. To access these checkpoints, first [sign up](https://ngc.nvidia.com/signup) for and [setup](https://ngc.nvidia.com/setup/installers/cli) the NVIDIA GPU Cloud (NGC) Registry CLI. Further documentation for downloading models can be found in the [NGC documentation](https://docs.nvidia.com/dgx/ngc-registry-cli-user-guide/index.html#topic_6_4_1). +We have provided pretrained [BERT-345M](https://ngc.nvidia.com/catalog/models/nvidia:megatron_bert_345m) and [GPT-345M](https://ngc.nvidia.com/catalog/models/nvidia:megatron_lm_345m) checkpoints to evaluate or for finetuning downstream tasks. To access these checkpoints, first [sign up](https://ngc.nvidia.com/signup) for and [setup](https://ngc.nvidia.com/setup/installers/cli) the NVIDIA GPU Cloud (NGC) Registry CLI. Further documentation for downloading models can be found in the [NGC documentation](https://docs.nvidia.com/dgx/ngc-registry-cli-user-guide/index.html#topic_6_4_1). Alternatively, you can directly download the checkpoints using: @@ -92,7 +122,7 @@ After installation, there are several possible workflows. The most comprehensive However, steps 1 and 2 can be replaced by using one of the pretrained models mentioned above. -We've provided several scripts for pretraining both BERT and GPT in [`examples`](./examples) directory, as well as scripts for both zero-shot and fine-tuned downstream tasks including MNLI, RACE, WikiText103, and LAMBADA evaluation. There is also a script for GPT interactive text generation. +We've provided several scripts for pretraining both BERT and GPT in the [`examples`](./examples) directory, as well as scripts for both zero-shot and fine-tuned downstream tasks including MNLI, RACE, WikiText103, and LAMBADA evaluation. There is also a script for GPT interactive text generation. # Training ## Data Preprocessing @@ -139,27 +169,28 @@ Further command line arguments are described in the source file [`preprocess_dat ## BERT Pretraining -The [`examples/pretrain_bert.sh`](./examples/pretrain_bert.sh) script runs single GPU 345M parameter BERT pretraining. Debugging is the primary use for single GPU training, as the code base and command line arguments are optimized for highly distributed training. Most of the arguments are fairly self-explanatory. By default, the learning rate decays linearly over the training iterations starting at `--lr` to a minimum set by `--min-lr` over `--lr-decay-iters` iterations. The fraction of training iterations used for warmup is set by `--lr-warmup-fraction`. While this is single GPU training, the batch size specified by `--micro-batch-size` is a single forward-backward path batch-size and the code will perform gradient accumulation steps until it reaches `global-batch-size` which is the batch size per iteration. The data is partitioned into a 949:50:1 ratio for training/validation/test sets (default is 969:30:1). 
This partitioning happens on the fly, but is consistent across runs with the same random seed (1234 by default, or specified manually with `--seed`). We use `train-iters` as the training iterations requested. Alternatively, one can provide `--train-samples` which is total number of samples to train on. If this option is present, then instead of providing `--lr-decay-iters`, one will need to provide `--lr-decay-samples`. +The [`examples/bert/train_bert_340m_distributed.sh`](examples/bert/train_bert_340m_distributed.sh) script runs single GPU 345M parameter BERT pretraining. Debugging is the primary use for single GPU training, as the code base and command line arguments are optimized for highly distributed training. Most of the arguments are fairly self-explanatory. By default, the learning rate decays linearly over the training iterations starting at `--lr` to a minimum set by `--min-lr` over `--lr-decay-iters` iterations. The fraction of training iterations used for warmup is set by `--lr-warmup-fraction`. While this is single GPU training, the batch size specified by `--micro-batch-size` is a single forward-backward path batch-size and the code will perform gradient accumulation steps until it reaches `global-batch-size` which is the batch size per iteration. The data is partitioned into a 949:50:1 ratio for training/validation/test sets (default is 969:30:1). This partitioning happens on the fly, but is consistent across runs with the same random seed (1234 by default, or specified manually with `--seed`). We use `train-iters` as the training iterations requested. Alternatively, one can provide `--train-samples` which is total number of samples to train on. If this option is present, then instead of providing `--lr-decay-iters`, one will need to provide `--lr-decay-samples`. -The logging, checkpoint-saving, and evaluation intervals are specified. Checkpointing the activations facilitates the training of larger models and/or batches. Note that the `--data-path` now includes the additional `_text_sentence` suffix added in preprocessing, but does not include the file extensions. +The logging, checkpoint-saving, and evaluation interval options are specified. Note that the `--data-path` now includes the additional `_text_sentence` suffix added in preprocessing, but does not include the file extensions. -Further command line arguments are described in the source file [`arguments.py`](./megatron/arguments.py). +Further command line arguments are described in the source file [`arguments.py`](./megatron/training/arguments.py). -To run `examples/pretrain_bert.sh`, make any desired modifications including setting the environment variables for `CHECKPOINT_PATH`, `VOCAB_FILE`, and `DATA_PATH`. Make sure to set these variables to their paths in the container. Then launch the container with Megatron and necessary paths mounted (as explained in [Setup](#setup)) and run the example script. +To run `train_bert_340m_distributed.sh`, make any desired modifications including setting the environment variables for `CHECKPOINT_PATH`, `VOCAB_FILE`, and `DATA_PATH`. Make sure to set these variables to their paths in the container. Then launch the container with Megatron and necessary paths mounted (as explained in [Setup](#setup)) and run the example script. ## GPT Pretraining -The `examples/pretrain_gpt.sh` script runs single GPU 345M parameter GPT pretraining. As mentioned above, single GPU training is primarily intended for debugging purposes, as the code is optimized for distributed training. 
+The `examples/gpt3/train_gpt3_175b_distributed.sh` script runs single GPU 345M parameter GPT pretraining. As mentioned above, single GPU training is primarily intended for debugging purposes, as the code is optimized for distributed training. It follows largely the same format as the previous BERT script with a few notable differences: the tokenization scheme used is BPE (which requires a merge table and a `json` vocabulary file) instead of WordPiece, the model architecture allows for longer sequences (note that the max position embedding must be greater than or equal to the maximum sequence length), and the `--lr-decay-style` has been set to cosine decay. Note that the `--data-path` now includes the additional `_text_document` suffix added in preprocessing, but does not include the file extensions. -Further command line arguments are described in the source file [`arguments.py`](./megatron/arguments.py). +Further command line arguments are described in the source file [`arguments.py`](./megatron/training/arguments.py). -`examples/pretrain_gpt.sh` can be launched the same way as described for BERT. Set the env vars and make any other modifications, launch the container with appropriate mounts, and run the script. +`train_gpt3_175b_distributed.sh` can be launched the same way as described for BERT. Set the env vars and make any other modifications, launch the container with appropriate mounts, and run the script. +More details in [`examples/gpt3/README.md`](./examples/gpt3/README.md) ## T5 Pretraining -Very similar to BERT and GPT, the `examples/pretrain_t5.sh` script runs single GPU "base" (~220M parameter) T5 pretraining. The primary difference from BERT and GPT is the addition of the following arguments to accommodate the T5 architecture: +Very similar to BERT and GPT, the `examples/t5/train_t5_220m_distributed.sh` script runs single GPU "base" (~220M parameter) T5 pretraining. The primary difference from BERT and GPT is the addition of the following arguments to accommodate the T5 architecture: * `--kv-channels` sets the inner dimension of the "key" and "value" matrices of all attention mechanisms in the model. For BERT and GPT this defaults to the hidden size divided by the number of attention heads, but can be configured for T5. @@ -169,19 +200,19 @@ Very similar to BERT and GPT, the `examples/pretrain_t5.sh` script runs single G All of the other arguments remain as they were for BERT and GPT pretraining. Run this example with the same steps described above for the other scripts. +More details in [`examples/t5/README.md`](./examples/t5/README.md) + ## Distributed Pretraining -The `examples/pretrain_{bert,gpt,t5}_distributed.sh` scripts use the PyTorch distributed launcher for distributed training. As such, multi-node training can be achieved by properly setting environment variables. See the official PyTorch [documentation](https://pytorch.org/docs/stable/elastic/run.html#launcher-api) for further description of these [environment variables](https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization). By default, multi-node training uses the [nccl](https://developer.nvidia.com/nccl) distributed backend. A simple set of additional arguments and the use of the PyTorch distributed module with the `torchrun` elastic launcher (equivalent to `python -m torch.distributed.run`) are the only additional requirements to adopt distributed training. See any of `examples/pretrain_{bert,gpt,t5}_distributed.sh` for more details. 
+The `pretrain_{bert,gpt,t5}_distributed.sh` scripts use the PyTorch distributed launcher for distributed training. As such, multi-node training can be achieved by properly setting environment variables. See the official PyTorch [documentation](https://pytorch.org/docs/stable/elastic/run.html#launcher-api) for further description of these [environment variables](https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization). By default, multi-node training uses the [nccl](https://developer.nvidia.com/nccl) distributed backend. A simple set of additional arguments and the use of the PyTorch distributed module with the `torchrun` elastic launcher (equivalent to `python -m torch.distributed.run`) are the only additional requirements to adopt distributed training. See any of `pretrain_{bert,gpt,t5}_distributed.sh` for more details. -We use two types of parallelism: data and model parallelism. We facilitate two distributed data parallel implementations: a simple one of our own that performs gradient all-reduce at the end of back propagation step, and Torch's distributed data parallel wrapper that overlaps gradient reduction with back propagation computation. To switch between these two options use `--DDP-impl local` or `--DDP-impl torch`, respectively. As expected, Torch distributed data parallelism is more efficient at larger model sizes. For example, for the 8.3 billion parameters model running on 512 GPUs, the scaling increases from 60% to 76% when Torch's distributed data parallel is used. However, the overlapping method requires more memory and for some configurations (e.g., 2.5 billion parameters using 2-way model parallel and 1.2 billion parameters with no model parallel) can make the overall training slower as a result. We empirically found that using a smaller model in those cases improves the training time. +We use two types of parallelism: data and model parallelism. Our data parallelism implementation is in `megatron/core/distributed`, and supports overlapping of the gradient reduction with the backward pass when the `--overlap-grad-reduce` command-line option is used. -Second, we developed a simple and efficient two-dimensional model-parallel approach. To use tensor model parallelism (splitting execution of a single transformer module over multiple GPUs, see Section 3 of [our paper](https://arxiv.org/pdf/1909.08053.pdf)), add the `--tensor-model-parallel-size` flag to specify the number of GPUs among which to split the model, along with the arguments passed to the distributed launcher as mentioned above. To use sequence parallelism specify `--sequence-parallel`, which requires tensor model parallel as it split among the same GPUs (more details in Section 4.2.2 of [our paper](https://arxiv.org/pdf/2205.05198.pdf)). +Second, we developed a simple and efficient two-dimensional model-parallel approach. To use the first dimension, tensor model parallelism (splitting execution of a single transformer module over multiple GPUs, see Section 3 of [our paper](https://arxiv.org/pdf/1909.08053.pdf)), add the `--tensor-model-parallel-size` flag to specify the number of GPUs among which to split the model, along with the arguments passed to the distributed launcher as mentioned above. To use the second dimension, sequence parallelism, specify `--sequence-parallel`, which also requires tensor model parallelism to be enabled because it splits across the same GPUs (more details in Section 4.2.2 of [our paper](https://arxiv.org/pdf/2205.05198.pdf)). 
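+
+As an illustration, the sketch below combines these data-parallel and tensor/sequence-parallel flags in a single-node launch. It is only a sketch: the model size, hyperparameters, and paths are placeholders, and the full argument list should be taken from the example scripts mentioned above.
+
+```
+# Illustrative launch only: 8 GPUs, 4-way tensor model parallelism with sequence parallelism,
+# and gradient reduction overlapped with the backward pass. Paths are placeholders.
+torchrun --nproc_per_node=8 pretrain_gpt.py \
+    --tensor-model-parallel-size 4 \
+    --sequence-parallel \
+    --overlap-grad-reduce \
+    --num-layers 24 \
+    --hidden-size 1024 \
+    --num-attention-heads 16 \
+    --seq-length 1024 \
+    --max-position-embeddings 1024 \
+    --micro-batch-size 4 \
+    --global-batch-size 64 \
+    --lr 0.00015 \
+    --min-lr 1.0e-5 \
+    --lr-decay-style cosine \
+    --lr-warmup-fraction 0.01 \
+    --train-iters 500000 \
+    --tokenizer-type GPT2BPETokenizer \
+    --vocab-file $VOCAB_FILE \
+    --merge-file $MERGE_FILE \
+    --data-path $DATA_PATH \
+    --split 949,50,1 \
+    --save $CHECKPOINT_PATH \
+    --load $CHECKPOINT_PATH \
+    --bf16
+```
+
+With 8 GPUs, 4-way tensor model parallelism, and no pipeline parallelism, the remaining factor of 2 is used for data parallelism.
+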
To use pipeline model parallelism (sharding the transformer modules into stages with an equal number of transformer modules on each stage, and then pipelining execution by breaking the batch into smaller microbatches, see Section 2.2 of [our paper](https://arxiv.org/pdf/2104.04473.pdf)), use the `--pipeline-model-parallel-size` flag to specify the number of stages to split the model into (e.g., splitting a model with 24 transformer layers across 4 stages would mean each stage gets 6 transformer layers each). - - -We have examples of how to use these two different forms of model parallelism the example scripts ending in `distributed_with_mp.sh`: +We have examples of how to use these two different forms of model parallelism the example scripts ending in `distributed_with_mp.sh`. Other than these minor changes, the distributed training is identical to the training on a single GPU. @@ -189,13 +220,15 @@ The interleaved pipelining schedule (more details in Section 2.2.2 of [our paper ## Activation Checkpointing and Recomputation -To reduce GPU memory usage so deploy a large model to a training system, we support activation checkpointing and recomputation. We support two levels of recompute granularity: `selective` and `full`. Selective recomputation is the default and recommended in almost all cases. It saves the activations that take less space and are expensive to recompute and recomputes activations that take a lot of space but are relatively cheap to recompute (see [our paper](https://arxiv.org/pdf/2205.05198) for details). To enable selective activation recompute simply use `--recompute-activations`. +To reduce GPU memory usage when training a large model, we support various forms of activation checkpointing and recomputation. Instead of all activations being stored in memory to be used during backprop, as was traditionally the case in deep learning models, only activations at certain "checkpoints" in the model are retained (or stored) in memory, and the other activations are recomputed on-the-fly when needed for backprop. Note that this kind of checkpointing, *activation* checkpointing, is very different from the checkpointing of model parameters and optimizer state, which is mentioned elsewhere. + +We support two levels of recompute granularity: `selective` and `full`. Selective recomputation is the default and is recommended in almost all cases. This mode retains in memory the activations that take less memory storage space and are more expensive to recompute and recomputes the activations that take more memory storage space but are relatively inexpensive to recompute. See [our paper](https://arxiv.org/pdf/2205.05198) for details. You should find that this mode maximizes performance while minimizing the memory required to store activations. To enable selective activation recompute simply use `--recompute-activations`. -For cases where memory is very tight, `full` checkpointing saves just the inputs to a transformer layer, or a block of transformer layers, and recomputes everything else. To turn on full activation recompute use `--recompute-granularity full`. When using full activation recomputation, there are two methods: `uniform` and `block`, chosen using the `--recompute-method` argument. +For cases where memory is very limited, `full` recompute saves just the inputs to a transformer layer, or a group, or block, of transformer layers, and recomputes everything else. To enable full activation recompute use `--recompute-granularity full`. 
When using `full` activation recompute, there are two methods: `uniform` and `block`, chosen using the `--recompute-method` argument. -* Uniform method uniformly divides the Transformer layers into groups of layers and stores the input activations of each group in the memory. The baseline group size is 1 and, in this case, the input activation of each Transformer layer is checkpointed. When the GPU memory is insufficient, increasing the number of layers per group reduces the memory usage thus enables running a bigger model. For example, when using the number of layers per group of 4, the input activation of each group of 4 Transformer layers is checkpointed. +* The `uniform` method uniformly divides the transformer layers into groups of layers (each group of size `--recompute-num-layers`) and stores the input activations of each group in memory. The baseline group size is 1 and, in this case, the input activation of each transformer layer is stored. When the GPU memory is insufficient, increasing the number of layers per group reduces the memory usage, enabling a bigger model to be trained. For example, when `--recompute-num-layers` is set to 4, only the input activation of each group of 4 transformer layers is stored. -* Block method checkpoints the input activations of a set number of individual Transformer layers per pipeline stage and do the rest of layers without any checkpointing. This method can be used to skip checkpointing some Transformer layers until the GPU memory is fully used, which is applicable only when there is unused GPU memory. Checkpointing fewer transformer layers avoids unnecessary activation recomputation in the backprop thus improves training performance. For example, when we specify 5 layers to checkpoint of 8 layers per pipeline stage, the input activations of only the first 5 Transformer layers are checkpointed and activation recomputation for the rest 3 layers is not needed in the backprop. +* The `block` method recomputes the input activations of a specific number (given by `--recompute-num-layers`) of individual transformer layers per pipeline stage and stores the input activations of the remaining layers in the pipeline stage. Reducing `--recompute-num-layers` results in storing the input activations to more transformer layers, which reduces the activation recomputation required in the backprop, thus improving training performance while increasing memory usage. For example, when we specify 5 layers to recompute of 8 layers per pipeline stage, the input activations of only the first 5 transformer layers are recomputed in the backprop step while the input activations for the final 3 layers are stored. `--recompute-num-layers` can be incrementally increased until the amount of memory storage space required is just small enough to fit in the available memory, thereby both maximally utilizing memory and maximizing performance. ## Distributed Optimizer @@ -212,6 +245,8 @@ Theoretical memory savings vary depending on the combination of the model's para | bf16 param, fp32 grads | 18 | 6 + 12/d | | fp32 param, fp32 grads | 16 | 8 + 8/d | +As with regular data parallelism, overlapping of the gradient reduction (in this case, a reduce-scatter) with the backward pass can be facilitated using the `--overlap-grad-reduce` flag. Additionally, overlapping of the parameter all-gather can be overlapped with the forward pass using `--overlap-param-gather`. + ## FlashAttention Usage: `--use-flash-attn`. Support attention head dimensions at most 128. 
@@ -227,23 +262,35 @@ pip install flash-attn ## GPT-3 Example -In `examples/pretrain_gpt3_175B.sh` we have provided an example of how to configure Megatron to run [GPT-3](https://arxiv.org/abs/2005.14165) with 175 billion parameters on 1024 GPUs. The script is designed for [slurm](https://slurm.schedmd.com/documentation.html) with [pyxis](https://github.com/NVIDIA/pyxis) plugin but can be easily adopted to any other scheduler. It uses 8-way and 16-way tensor and pipeline parallelism, respectively. With options `global-batch-size 1536` and `rampup-batch-size 16 16 5859375`, the training will start with global batch size 16 and linearly increase the global batch size to 1536 over 5,859,375 samples with incremental steps 16. The training dataset can be either a single set or a multiple datasets combined with a set of weights. +In `examples/gpt3/train_gpt3_175b_distributed.sh` we have provided an example of how to configure Megatron to train [GPT-3](https://arxiv.org/abs/2005.14165) with 175 billion parameters on 1024 GPUs. The script is designed for [slurm](https://slurm.schedmd.com/documentation.html) with [pyxis](https://github.com/NVIDIA/pyxis) plugin but can be easily adopted to any other scheduler. It uses 8-way tensor parallelism and 16-way pipeline parallelism. With options `global-batch-size 1536` and `rampup-batch-size 16 16 5859375`, the training will start with global batch size 16 and linearly increase the global batch size to 1536 over 5,859,375 samples with incremental steps 16. The training dataset can be either a single set or a multiple datasets combined with a set of weights. With full global batch size of 1536 on 1024 A100 GPUs, each iteration takes around 32 seconds resulting in 138 teraFLOPs per GPU which is 44% of the theoretical peak FLOPs. +## Retro and InstructRetro + -## Retro +Retro [(Borgeaud et al., 2022)](https://arxiv.org/abs/2112.04426) is an autoregressive decoder-only language model (LM) pretrained with retrieval-augmentation. +Retro features practical scalability to support large-scale pretraining from scratch by retrieving from trillions of tokens. +Pretraining with retrieval provides a more efficient storage mechanism of factual knowledge, when compared to storing factual knowledge implicitly within the network's parameters, thus largely reducing model parameters while achieving lower perplexity than standard GPT. +Retro also provides the flexibility to update the +knowledge stored in LMs [(Wang et al., 2023a)](https://arxiv.org/abs/2304.06762) +by updating the retrieval database without training LMs again. -See: +InstructRetro [(Wang et al., 2023b)](https://arxiv.org/abs/2310.07713) further scales up the size of Retro to 48B, featuring the largest LLM pretrained with retrieval (as of December 2023). +The obtained foundation model, Retro 48B, largely outperforms the GPT counterpart in terms of perplexity. +With instruction tuning on Retro, InstructRetro demonstrates significant improvement over the instruction tuned GPT on downstream tasks in the zero-shot setting. Specifically, the average improvement of InstructRetro is 7% over its GPT counterpart across 8 short-form QA tasks, and 10% over GPT across 4 challenging long-form QA tasks. We also find that one can ablate the encoder from InstructRetro architecture and directly use the InstructRetro decoder backbone as GPT, while achieving comparable results. -- `tools/retro/README.md` for an overview. -- `tools/retro/examples/get_preprocess_cmd.sh` for an example of common preprocessing arguments. 
-- `tools/retro/examples/preprocess_data.sh` for an example of how to preprocess data. -- `tools/retro/examples/pretrain_model.sh` for an example of how to pretrain a model. +In this repo, we provide an end-to-end reproduction guide to implement Retro and InstructRetro, covering +- **Retrieval database construction**, which supports billions or even trillions of tokens as a large-scale retrieval database. +- **Pretraining with retrieval**, which supports pretraining from scratch and pretraining from a pretrained GPT model (Retro-fitting). +- **Instruction tuning**, where we provide an open-source instruction tuning dataset and the training recipe for instruction tuning on Retro. +- **Downstream task evaluation**, where we provide the text generation and evaluation scripts for zero-shot question answering tasks. -Retro is a retrieval-enhanced model that is based on GPT. As described in [Improving language models by retrieving from trillions of tokens](https://arxiv.org/abs/2112.04426), Retro retrieves from a database of document chunks by performing locality search using a sample's tokens. The retrieval database can be large -- often billions or even trillions of tokens -- and provides a more efficient storage mechanism of factual knowledge, when compared to storing factual knowledge implicitly within the network's parameters. +See [tools/retro/README.md](tools/retro/README.md) for a detailed overview. -Using Retro requires two steps: 1) preprocessing the retrieval database and pretraining neighbors, and 2) pretraining a model using this data. Please see `tools/retro/README.md` for a detailed overview. +## Mamba-based Language Models + +See [examples/mamba](./examples/mamba) for details. +## Mixture of Experts +MoE (Mixture of Experts) is a powerful LLM architecture implemented in the Megatron-Core framework, designed to enhance the efficiency and scalability of large language models. It leverages **Expert Parallelism**, allowing multiple experts to be distributed across different workers, where each worker processes distinct batches of training samples. This method significantly increases computational throughput, enabling models to achieve high performance metrics, such as 47% MFU during BF16 training for 8x7B on H100. + +Key Features of MoE: +- **Parallelism Techniques**: MoE combines various parallelism strategies, including Expert Parallelism, Data Parallelism, Tensor Parallelism, Sequence Paralleism, Pipeline Parallelism, and Context Parallelism. This combination allows for handling larger model variants effectively. +- **Router and Load Balancing**: The system employs advanced routing mechanisms like the Top-K router and utilizes load balancing algorithms to optimize token distribution among experts. +- **Performance Optimizations**: Techniques such as GroupedGEMM and FP8 training enhance the efficiency of MoE models, particularly when multiple experts are involved. +- **Token Dispatch Mechanism**: MoE supports both dropless and token drop strategies to manage token distribution effectively across experts. + +For a comprehensive overview of MoE training configurations and optimizations, please refer to the detailed README located at [megatron/core/transformer/moe/README.md](./megatron/core/transformer/moe/README.md). + # Evaluation and Tasks We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also finetune your model from a pretrained checkpoint on other corpora as desired. 
To do so, simply add the `--finetune` flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the `--finetune` flag before continuing, otherwise the training will start again from the beginning. @@ -332,7 +390,7 @@ We provide several command line arguments, detailed in the scripts listed below, Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on fewer GPUs in downstream tasks. The following script accomplishes this. This example reads in a GPT model with 4-way tensor and 4-way pipeline model parallelism and writes out a model with 2-way tensor and 2-way pipeline model parallelism.
-python tools/checkpoint/util.py \
+python tools/checkpoint/convert.py \
         --model-type GPT \
         --load-dir checkpoints/gpt3_tp4_pp4 \
         --save-dir checkpoints/gpt3_tp2_pp2 \
@@ -345,7 +403,7 @@ Several downstream tasks are described for both GPT and BERT models below. They
 
 ## GPT Text Generation
 
-We have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also few optional parameters: `temperature`, `top-k`and `top-p`. See `--help` or the source file for more information. See [examples/run_text_generation_server_345M.sh](examples/run_text_generation_server_345M.sh) for an example of how to run the server.
+We have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also a few optional parameters: `temperature`, `top-k`, and `top-p`. See `--help` or the source file for more information. See [examples/inference/run_text_generation_server_345M.sh](examples/inference/run_text_generation_server_345M.sh) for an example of how to run the server.
 
 Once the server is running, you can use `tools/text_generation_cli.py` to query it; it takes one argument, the host the server is running on.
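+
+For example, if the server is listening on port 5000 of the local machine (as in the `curl` example below), the query might look like the following; the host and port are placeholders for your setup:
+
+```
+python tools/text_generation_cli.py localhost:5000
+```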
 
@@ -359,12 +417,12 @@ You can also use CURL or any other tools to query the server directly:
 curl 'http://localhost:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8'  -d '{"prompts":["Hello world"], "tokens_to_generate":1}'
 
-See [megatron/text_generation_server.py](megatron/text_generation_server.py) for more API options. +See [megatron/inference/text_generation_server.py](megatron/inference/text_generation_server.py) for more API options. ### Detoxify GPT via Self-generation -We include an example in `examples/detxoify_lm/` to detoxify language models by leveraging the generative power of language models. +We include an example in `examples/academic_paper_scripts/detxoify_lm/` to detoxify language models by leveraging the generative power of language models. -See [examples/detxoify_lm/README.md](examples/detxoify_lm/README.md) for step-by-step tutorials on how to perform domain-adaptive training and detoxify LM using self-generated corpus. +See [examples/academic_paper_scripts/detxoify_lm/README.md](examples/academic_paper_scripts/detxoify_lm/README.md) for step-by-step tutorials on how to perform domain-adaptive training and detoxify LM using self-generated corpus. ## GPT Evaluation @@ -407,7 +465,7 @@ python tasks/main.py \ ### LAMBADA Cloze Accuracy To compute LAMBADA cloze accuracy (the accuracy of predicting the last token given the preceding tokens) we utilize a detokenized, processed version of the [LAMBADA dataset](https://github.com/cybertronai/bflm/blob/master/lambada_test.jsonl). -We use the following command to run LAMBADA evaluation on a 345M parameter model. Note that the `--strict-lambada` flag should be used to require whole word matching. Make that `lambada` is part of the file path. +We use the following command to run LAMBADA evaluation on a 345M parameter model. Note that the `--strict-lambada` flag should be used to require whole word matching. Ensure that `lambada` is part of the file path.
 TASK="LAMBADA"
@@ -503,7 +561,13 @@ python tasks/main.py \
 
 The Llama-2 [family of models](https://ai.meta.com/llama/) are an open-source set of pretrained & finetuned (for chat) models that have achieved strong results across a wide set of benchmarks. At the time of release, Llama-2 models achieved among the best results for open-source models, and were competitive with the closed-source GPT-3.5 model (see https://arxiv.org/pdf/2307.09288.pdf).
 
-The Llama-2 checkpoints can be loaded into Megatron for inference and finetuning. See documentation [here](docs/llama2.md).
+The Llama-2 checkpoints can be loaded into Megatron for inference and finetuning. See documentation [here](docs/llama_mistral.md).
+
+# Model Optimization and Deployment
+The Megatron-Core (MCore) `GPTModel` family supports advanced quantization algorithms and high-performance inference through TensorRT-LLM.
+
+## Quantization and TensorRT-LLM Deployment
+See [Megatron Model Optimization and Deployment](examples/inference/quantization/README.md) for `llama2` and `nemotron3` examples.
 
 # Datasets
 We do not host any datasets for GPT or BERT training, however, we detail their collection so that our results may be reproduced.
@@ -511,16 +575,93 @@ We do not host any datasets for GPT or BERT training, however, we detail their c
 ## Collecting Wikipedia Training Data
 We recommend following the Wikipedia data extraction process specified by Google research: "the recommended pre-processing is to download [the latest dump](https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2), extract the text with [WikiExtractor.py](https://github.com/attardi/wikiextractor), and then apply any necessary cleanup to convert it into plain text."
 
-We recommend using the `--json` argument when using WikiExtractor, which will dump the Wikipedia data into loose json format (one json per line), making it more manageable on the file system and also readily consumable by our codebase. We recommend further preprocessing this json dataset by nltk punctuation standardization. For BERT training, use the `--split-sentences` flag to `preprocess_data.py` as described [above](#data-preprocessing) to include sentence breaks in the produced index. If you'd like to use Wikipedia data for GPT training you should still clean it with nltk/spacy/ftfy, but do not use the `--split-sentences` flag.
+We recommend using the `--json` argument when using WikiExtractor, which will dump the Wikipedia data into loose json format (one json object per line), making it more manageable on the file system and also readily consumable by our codebase. We recommend further preprocessing this json dataset with nltk punctuation standardization. For BERT training, use the `--split-sentences` flag to `preprocess_data.py` as described [above](#data-preprocessing) to include sentence breaks in the produced index. If you'd like to use Wikipedia data for GPT training you should still clean it with nltk/spacy/ftfy, but do not use the `--split-sentences` flag.
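+
+As a rough sketch of this pipeline for BERT data (file names, vocab file, and worker count below are illustrative, not prescribed; only `--json` and `--split-sentences` come from the description above, and the remaining flags follow the [data preprocessing](#data-preprocessing) example and may differ for your setup):
+
+```
+# dump Wikipedia into loose json, one json object per line
+python -m wikiextractor.WikiExtractor enwiki-latest-pages-articles.xml.bz2 --json -o extracted
+# concatenate the extracted shards into a single loose-json file, then build the index
+cat extracted/*/wiki_* > wikipedia.json
+python tools/preprocess_data.py \
+       --input wikipedia.json \
+       --output-prefix my-bert \
+       --vocab-file bert-vocab.txt \
+       --tokenizer-type BertWordPieceLowerCase \
+       --split-sentences \
+       --workers 8
+```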
 
 ## Collecting GPT Webtext Data
-We utilize the publicly available [OpenWebText](https://github.com/eukaryote31/openwebtext) library from [jcpeterson](https://github.com/jcpeterson/openwebtext) and [eukaryote31's](https://github.com/eukaryote31/openwebtext) work to download urls. We then filtered, cleaned, and deduplicated all downloaded content according to the procedure described in our [openwebtext](./tools/openwebtext) directory. For reddit URLs corresponding to content up to October 2018 we arrived at approximately 37GB of content.
+We utilize the publicly available [OpenWebText](https://github.com/eukaryote31/openwebtext) library from [jcpeterson](https://github.com/jcpeterson/openwebtext) and [eukaryote31's](https://github.com/eukaryote31/openwebtext) work to download urls. We then filter, clean, and deduplicate all downloaded content according to the procedure described in our [openwebtext](./tools/openwebtext) directory. For reddit URLs corresponding to content up to October 2018 we arrived at approximately 37GB of content.
 
 # Reproducibility
-Megatron training is intended to be bitwise reproducible. This means that the same training config run twice in the same HW and SW environment should produce identical model checkpoints, losses and accuracy metric values (iteration time metrics may vary).
+Megatron training can be bitwise reproducible; to enable this mode use `--deterministic-mode`. This means that the same training config run twice in the same HW and SW environment should produce identical model checkpoints, losses and accuracy metric values (iteration time metrics may vary).
+
+There are currently three known Megatron optimizations that break reproducibility whilst still producing almost identical training runs (a combined example launch is sketched after this list):
+1. The specific NCCL algorithm that is used during an all-reduce (as specified by the environment variable `NCCL_ALGO`) is important. We have tested the following: `^NVLS`, `Tree`, `Ring`, `CollnetDirect`, `CollnetChain`. The code admits the use of `^NVLS`, which allows NCCL the choice of non-NVLS algorithms; its choice seems to be stable.
+2. Flash attention is non-deterministic; do not use `--use-flash-attn`.
+3. If using Transformer Engine, you must also set the environment variable `NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`.
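+
+Combining the settings above, a bitwise-reproducible run might be launched as in the following sketch (the launcher, GPU count, and elided arguments are placeholders for your own configuration):
+
+```
+NCCL_ALGO='^NVLS' \
+NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 \
+torchrun --nproc_per_node=8 pretrain_gpt.py \
+    --deterministic-mode \
+    <your usual model, data, and parallelism arguments, without --use-flash-attn>
+```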
+
+In addition, determinism has only been verified in NGC PyTorch containers 23.12 and newer. If you observe nondeterminism in Megatron training under other circumstances, please open an issue.
+
+# Checkpoint conversion
+
+We support two forms of model conversion:
+
+1. Model class conversion (i.e., the `GPTModel` in `megatron.legacy.model` vs. `megatron.core.models`)
+2. Checkpoint format conversion (i.e., distributed vs. non-distributed checkpoint)
 
-There are currently two known Megatron optimizations that break reproducibility whilst still producing almost identical training runs. The following workarounds should be applied in cases where reproducibility is required:
-1. When training using `--bf16`, reproducbility is only obtained when the checkpointing and resume schedule of training is identical. If the checkpointing schedule will change, i.e. checkpointing and resume will occur at different iterations, the option `--no-bias-gelu-fusion` should be used.
-2. Flash attention is non-deterministic. If reproducibility is required do not use `--use-flash-attn`.
+## Model class conversion
 
-These sources of non-determinism are under active investigation. If you observe non-determinism in Megatron training under other circumstances please open an issue.
+Megatron supports converting between different model classes, including internal model classes (we currently have the older `legacy` models, and the newer `core` models) and external model classes (such as Meta, Huggingface, Mistral, and Mixtral models). Additionally, during this conversion, one can update the parallel state of the model (i.e., changing tensor and pipeline model parallelism).
+
+We provide the tool `tools/checkpoint/convert.py` to convert between model classes. Some important arguments include:
+
+- `--model-type`: `GPT` or `BERT`
+- `--loader`: format of the existing checkpoint. Supported formats include:
+  - `legacy`: our older model classes (under `megatron.legacy.model`)
+  - `core`: our newer model classes (under `megatron.core.models`)
+  - `llama_mistral`: for loading Llama and Mistral models (supports Meta and Huggingface formats)
+  - `mixtral_hf`: for loading Mixtral models (Huggingface only)
+- `--load-dir`: directory for loading the existing checkpoint
+- `--saver`: `legacy` or `core` (see descriptions under `--loader`)
+- `--save-dir`: directory for saving the new checkpoint
+- `--target-tensor-parallel-size`: new tensor model parallel size
+- `--target-pipeline-parallel-size`: new pipeline model parallel size
+
+For more argument details, please see the main script (`convert.py`), loader scripts (`loader_core.py`, `loader_legacy.py`, `loader_llama_mistral.py`, `loader_mixtral_hf.py`), or saver scripts (`saver_core.py`, `saver_legacy.py`).
+
+An example command for converting a GPT model from the old format (`legacy`) to the new format (`core`) would look as follows:
+
+```
+python tools/checkpoint/convert.py \
+>   --model-type GPT \
+>   --loader legacy \
+>   --load-dir ${LEGACY_FORMAT_DIR} \
+>   --saver core \
+>   --save-dir ${CORE_FORMAT_DIR} \
+>   --target-tensor-parallel-size ${TP} \
+>   --target-pipeline-parallel-size ${PP}
+```
+
+For examples of converting Llama/Mistral models into Megatron, please see [here](docs/llama_mistral.md).
+
+## Checkpoint format conversion
+
+Megatron offers multiple checkpoint formats, including:
+
+- `torch`: The basic checkpoint format, with sequential reads & writes, tied to a specific tensor/pipeline model parallel state (TP/PP state). (Although a checkpoint is tied to a specific TP/PP state, it can still be converted manually via the model class converter described above.)
+- `torch_dist`: The distributed checkpoint format, with fast parallel reads & writes; it is also parallel-state agnostic (i.e., the same checkpoint can be loaded with different TP/PP setups).
+
+Generally speaking, `torch_dist` is the more modern and recommended checkpoint format due to its speed. However, depending on the use case, it may be desirable to convert between these two formats. To do so, launch your *training* script (e.g., via `pretrain_gpt.py`) as you normally would, but with two additional arguments:
+
+- `--ckpt-convert-format ${FORMAT}`: `${FORMAT}` can be one of `torch` or `torch_dist`, as described above.
+- `--ckpt-convert-save ${PATH_TO_SAVE_NEW_FORMAT}`: this path should be different than your existing `--load`/`--save` paths, to avoid overwriting the existing checkpoint. After converting, use this new path for your `--load`/`--save` paths.
+
+The checkpoint format converter launches the model just as one normally would for training, but before running any training iterations it saves the checkpoint in the new format and then exits. It is important that all other launch args remain the same, so that the system can correctly read the existing checkpoint.
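+
+For illustration, converting an existing `torch` checkpoint to `torch_dist` might look like the following sketch (`${EXISTING_TORCH_CKPT_DIR}` and `${NEW_TORCH_DIST_CKPT_DIR}` are placeholder paths; the elided arguments stand for the exact arguments of your original run):
+
+```
+python pretrain_gpt.py \
+    <all model, data, and parallelism arguments from the original run> \
+    --load ${EXISTING_TORCH_CKPT_DIR} \
+    --ckpt-convert-format torch_dist \
+    --ckpt-convert-save ${NEW_TORCH_DIST_CKPT_DIR}
+```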
+
+# Projects Using Megatron
+Below are some of the projects where we have directly used Megatron:
+* [BERT and GPT Studies Using Megatron](https://arxiv.org/pdf/1909.08053.pdf)
+* [BioMegatron: Larger Biomedical Domain Language Model](https://www.aclweb.org/anthology/2020.emnlp-main.379.pdf)
+* [End-to-End Training of Neural Retrievers for Open-Domain Question Answering](https://arxiv.org/abs/2101.00408)
+* [Large Scale Multi-Actor Generative Dialog Modeling](https://www.aclweb.org/anthology/2020.acl-main.8.pdf)
+* [Local Knowledge Powered Conversational Agents](https://arxiv.org/abs/2010.10150)
+* [MEGATRON-CNTRL: Controllable Story Generation with External Knowledge Using Large-Scale Language Models](https://www.aclweb.org/anthology/2020.emnlp-main.226.pdf)
+* [RACE Reading Comprehension Dataset Leaderboard](http://www.qizhexie.com/data/RACE_leaderboard.html)
+* [Training Question Answering Models From Synthetic Data](https://www.aclweb.org/anthology/2020.emnlp-main.468.pdf)
+* [Few-shot Instruction Prompts for Pretrained Language Models to Detect Social Biases](https://arxiv.org/abs/2112.07868)
+* [Exploring the Limits of Domain-Adaptive Training for Detoxifying Large-Scale Language Models](https://arxiv.org/abs/2202.04173)
+* [Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model](https://arxiv.org/abs/2201.11990)
+* [Multi-Stage Prompting for Knowledgeable Dialogue Generation](https://arxiv.org/abs/2203.08745)
+* [Evaluating Parameter Efficient Learning for Generation](https://aclanthology.org/2022.emnlp-main.319.pdf)
+* [Shall We Pretrain Autoregressive Language Models with Retrieval? A Comprehensive Study](https://arxiv.org/abs/2304.06762)
+* [InstructRetro: Instruction Tuning post Retrieval-Augmented Pretraining](https://arxiv.org/abs/2310.07713)
+* [An Empirical Study of Mamba-based Language Models](https://arxiv.org/abs/2406.07887)
diff --git a/compute_memory_usage.py b/compute_memory_usage.py
deleted file mode 100644
index ca6e3aacde..0000000000
--- a/compute_memory_usage.py
+++ /dev/null
@@ -1,79 +0,0 @@
-from megatron.initialize import initialize_megatron
-from megatron import get_args
-
-
-def compute_weight_and_optimizer_memory(args):
-    assert args.sequence_parallel
-    num_parameters_in_transformer_layers = (
-        10
-        * args.num_layers
-        * args.hidden_size
-        * args.hidden_size
-        * (
-            1
-            + (args.num_query_groups / (5.0 * args.num_attention_heads))
-            + (2 / (5 * args.hidden_size))
-            + (1 / (5 * args.num_layers * args.hidden_size))
-        )
-    )
-    embedding_size = args.hidden_size * args.padded_vocab_size
-    if args.untie_embeddings_and_output_weights:
-        num_parameters_with_embeddings = num_parameters_in_transformer_layers + (2 * embedding_size)
-    else:
-        num_parameters_with_embeddings = num_parameters_in_transformer_layers + embedding_size
-    print(f"Number of parameters in billions: {num_parameters_with_embeddings / 10**9:.2f}")
-
-    # Most loaded model shard has (1/pp_size transformer layers + 1 embedding layer) / tp_size.
-    num_parameters_on_most_loaded_model_shard = (
-        (num_parameters_in_transformer_layers / args.pipeline_model_parallel_size) + embedding_size
-    ) / args.tensor_model_parallel_size
-    # Other shards just have (1/pp_size transformer layers) / tp_size.
-    num_parameters_on_other_model_shards = num_parameters_in_transformer_layers / (
-        args.pipeline_model_parallel_size * args.tensor_model_parallel_size
-    )
-
-    print(
-        f"Number of parameters in most loaded shard in billions: {num_parameters_on_most_loaded_model_shard / 10**9:.4f}"
-    )
-    print(
-        f"Number of parameters in other shards in billions: {num_parameters_on_other_model_shards / 10**9:.4f}"
-    )
-
-    num_bytes_per_parameter = (
-        18 if not args.use_distributed_optimizer else 6 + (12 / args.data_parallel_size)
-    )
-    return num_parameters_on_most_loaded_model_shard * num_bytes_per_parameter
-
-
-def compute_activation_memory(args):
-    # Using formula in Table 2 of https://arxiv.org/pdf/2205.05198.pdf.
-    assert args.recompute_granularity == 'selective'
-    activation_memory = (
-        args.seq_length * args.micro_batch_size * args.hidden_size * args.num_layers
-    ) * 34
-
-    # Multiply by interleaved PP memory factor.
-    activation_memory *= 1 + (
-        (args.pipeline_model_parallel_size - 2)
-        / (args.pipeline_model_parallel_size * args.virtual_pipeline_model_parallel_size)
-    )
-    return activation_memory / args.tensor_model_parallel_size
-
-
-def compute_total_memory(args):
-    weight_and_optimizer_memory = compute_weight_and_optimizer_memory(args)
-    activation_memory = compute_activation_memory(args)
-    total_memory = weight_and_optimizer_memory + activation_memory
-    print(
-        f"(DP size, PP size, TP size) = {(args.data_parallel_size, args.pipeline_model_parallel_size, args.tensor_model_parallel_size)}, "
-        f"Weight and optimizer memory: {weight_and_optimizer_memory / (1024 * 1024):.2f} MB, "
-        f"Activation memory: {activation_memory / (1024 * 1024):.2f} MB, "
-        f"Total memory: {total_memory / (1024 * 1024):.2f} MB\n"
-    )
-
-
-if __name__ == "__main__":
-    initialize_megatron(allow_no_cuda=True, skip_mpu_initialization=True)
-    args = get_args()
-
-    compute_total_memory(args)
diff --git a/docs/distrib_optimizer.md b/docs/distrib_optimizer.md
deleted file mode 100644
index def23b20eb..0000000000
--- a/docs/distrib_optimizer.md
+++ /dev/null
@@ -1,54 +0,0 @@
-# Distributed Optimizer
-
-The motivation for the distributed optimizer is to save memory by distributing the optimizer state evenly across data parallel ranks, versus the current method of replicating the optimizer state across data parallel ranks. As described in https://arxiv.org/abs/1910.02054, this branch specifically implements the following:
-
-- [yes] distribute all 'non-overlapping' optimizer state (i.e., model params already in fp32 are NOT distributed)
-- [no] distribute model gradients
-- [no] distribute model parameters
-
-Theoretical memory savings vary depending on the combination of the model's param dtype and grad dtype. In the current implementation, the theoretical number of bytes per parameter is (where 'd' is the data parallel size):
-
-|        | Non-distributed optim | Distributed optim |
-| ------ | ------ | ------ |
-| float16 param, float16 grads | 20 | 4 + 16/d |
-| float16 param, fp32 grads    | 18 | 6 + 12/d |
-| fp32 param, fp32 grads       | 16 | 8 + 8/d  |
-
-The implementation of the distributed optimizer is centered on using the contiguous grad buffer for communicating grads & params between the model state and the optimizer state. The grad buffer at any given moment either holds:
-
-1. all model grads
-2. a 1/d size _copy_ of the main grads (before copying to the optimizer state)
-3. a 1/d size _copy_ of the main params (after copying from the optimizer state)
-4. all model params
-5. zeros (or None), between iterations
-
-The grad buffer is used for performing reduce-scatter and all-gather operations, for passing grads & params between the model state and optimizer state. With this implementation, no dynamic buffers are allocated.
-
-The figures below illustrate the grad buffer's sharding scheme, and the key steps of the distributed optimizer's param update:
-
-## Data flow
-
-![Data flow](images/distrib_optimizer/data_flow.png)
-
-## Sharding scheme
-
-![Sharding scheme](images/distrib_optimizer/sharding_scheme.png)
-
-## Key steps
-
-_(note: using illustrations above, and assuming fp16 grads)_
-
-- Backward pass finishes (grad buffer holds 16 fp16 grad elements)
-- Call reduce-scatter on each DP rank
-- Each DP rank now has 4 elements within the grad buffer that are fully reduced (remaining 12 elements are garbage)
-- Each DP rank copies its relevant 4 fp16 grad elements from the grad buffer into 4 fp32 main grad elements (separate buffer, owned by the optimizer); i.e.
-  - DP rank 0 copies elements [0:4]
-  - DP rank 1 copies elements [4:8]
-  - DP rank 2 copies elements [8:12]
-  - DP rank 3 copies elements [12:16]
-- Optimizer.step()
-- Each DP rank copies its 4 fp32 main (/optimizer) param elements into the corresponding 4 fp16 elements in the grad buffer
-- Call all-gather on each DP rank
-- Grad buffer now contains all 16, fully updated, fp16 model param elements
-- Copy updated model params from grad buffer into their respective param tensors
-- (At this point, grad buffer is ready to be zero'd for the next iteration)
diff --git a/docs/images/distrib_optimizer/data_flow.png b/docs/images/distrib_optimizer/data_flow.png
deleted file mode 100644
index d48fc134c4..0000000000
Binary files a/docs/images/distrib_optimizer/data_flow.png and /dev/null differ
diff --git a/docs/images/distrib_optimizer/sharding_scheme.png b/docs/images/distrib_optimizer/sharding_scheme.png
deleted file mode 100644
index b07c25b05f..0000000000
Binary files a/docs/images/distrib_optimizer/sharding_scheme.png and /dev/null differ
diff --git a/docs/llama2.md b/docs/llama2.md
deleted file mode 100644
index 9043a2b95d..0000000000
--- a/docs/llama2.md
+++ /dev/null
@@ -1,171 +0,0 @@
-# Llama-2 Inference and Finetuning
-
-The Llama-2 [family of models](https://ai.meta.com/llama/) are an open-source set of pretrained & finetuned (for chat) models that have achieved strong results across a wide set of benchmarks. At the time of release, Llama-2 models achieved among the best results for open-source models, and were competitive with the closed-source GPT-3.5 model (see https://arxiv.org/pdf/2307.09288.pdf).
-
-Llama-2 checkpoints can be loaded into Megatron for inference and for finetuning. Loading these checkpoints consists of three steps:
-
-1. Get access to download the checkpoints.
-2. Convert the checkpoints from Meta/Huggingface format to Megatron format.
-3. Setup arguments for launching the model.
-
-The following sections detail these steps. The final section lists benchmark result comparisons between: 1) Llama-2 inference code running the Meta-format checkpoints, and 2) Megatron inference code running the converted checkpoints.
-
-# Contents
-  * [Download Meta or Huggingface checkpoints](#download-meta-or-huggingface-checkpoints)
-  * [Convert checkpoint format](#convert-checkpoint-format)
-    * [Meta format](#meta-format)
-    * [Huggingface format](#huggingface-format)
-  * [Launch model](#launch-model)
-    * [Megatron](#launch-megatron)
-    * [Meta](#launch-meta)
-    * [Huggingface](#launch-hf)
-  * [Benchmark results](#benchmark-results)
-
-# Download Meta or Huggingface checkpoints
-
-Users must first apply for access to download the Llama-2 checkpoints either directly from [Meta](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) or through [Huggingface](https://huggingface.co/docs/transformers/main/model_doc/llama2) (HF). The checkpoints are available in two formats, Meta's native format (available from both the Meta and HF links), and HF's format (available only from HF). Either format can be converted to Megatron, as detailed next.
-
-# Convert checkpoint format
-
-Depending on which checkpoint format is downloaded (Meta or HF), one or two steps must be taken to convert to Megatron format.
-
-### Meta format
-
-The Meta format checkpoints must first be converted to HF format before converting to Megatron format. The `transformers` package is required for the first step, and must have version >=4.31.0 (e.g., `pip install transformers>=4.31.0`). (**Note**: we have specifically tested with versions `4.31.0` and `4.32.0`; your experience may vary with newer versions.) Assuming the downloaded checkpoints are in `$CHECKPOINT_DIR` (with separate sub-directories for 7B, 13B, 70B, etc.), the following example command can be used to convert from Llama-2 format to HF format:
-
-```
-$>: python $LIB_DIR/transformers/models/llama/convert_llama_weights_to_hf.py \
- >    --input_dir $LLAMA_FORMAT_DIR \
- >    --output_dir $HF_FORMAT_DIR \
- >    --model_size 7B`
-```
-
-Valid values for `--model_size` include `7B`, `13B`, and `70B` (for pretrained-only models), and `7Bf`, `13Bf`, and `70Bf` (for chat-finetuned models). Use `python convert_llama_weights_to_hf.py --help` for additional argument details. Once the checkpoints have been converted to HF format, proceed to the Huggingface format section below.
-
-### Huggingface format
-
-The HF checkpoints can be converted to Megatron format by using Megatron's own Llama-2 checkpoint converter for HF format (see script `tools/checkpoint/loader_llama2_hf.py`). One important argument that must be set correctly is the tensor parallel size (`TP`) for each model. The following table shows these values:
-
-| Model size | Tensor parallel size (`TP`) |
-| ---------- | --------------------------- |
-|  7B        | 1                           |
-| 13B        | 2                           |
-| 70B        | 8                           |
-
-Using these values for `TP`, along with the path to the Llama-2 tokenizer model (automatically downloaded with original checkpoint download; see `${TOKENIZER_MODEL}` below), run the following command from the root of your Megatron source code to convert from HF format to Megatron format:
-
-```
-$>: python tools/checkpoint/util.py \
- >    --model-type GPT \
- >    --loader llama2_hf \
- >    --saver megatron \
- >    --target-tensor-parallel-size ${TP} \
- >    --load-dir ${HF_FORMAT_DIR} \
- >    --save-dir ${MEGATRON_FORMAT_DIR} \
- >    --tokenizer-model ${TOKENIZER_MODEL}
-```
-
-After this conversion, we are ready to load the checkpoints into a Megatron GPT model.
-
-# Launch model
-
-### Launch Megatron
-
-If loading for either inference or finetuning, use the following arguments:
-
-```
---tensor-model-parallel-size ${TP} \
---pipeline-model-parallel-size 1 \
---seq-length 4096 \
---max-position-embeddings 4096 \
---tokenizer-type Llama2Tokenizer \
---tokenizer-model ${TOKENIZER_MODEL} \
---load ${CHECKPOINT_DIR} \
---exit-on-missing-checkpoint \
---use-checkpoint-args \
---no-load-optim \
---no-load-rng \
---fp16 \
---untie-embeddings-and-output-weights \
---use-rotary-position-embeddings \
---normalization RMSNorm \
---no-position-embedding \
---no-masked-softmax-fusion \
---no-query-key-layer-scaling \
-```
-
-### Launch Meta
-
-Meta checkpoints can be launched with: https://github.com/facebookresearch/llama
-
-### Launch Huggingface
-
-Huggingface checkpoints can be launched with: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
-
-# Benchmark results
-
-The tables below list the benchmark comparisons between native Llama-2 (using Meta's checkpoint and Meta's inference code) and Megatron (using a converted HF checkpoint and Megatron's inference code).
-
-The values are the percent error between Megatron and Llama-2, calculated using the formula: `| - | / `, where the type of score is detailed before each table. Across all tests (80 total per model size), the mean error is 0.15%. The small difference in benchmark scores between the two models is due to minor arithmetic differences in implementation that alter the numerics slightly. Some of the factors that influence this difference include:
-
-- Megatron performs batch matrix multiplications in a couple places, such as within self attention and in SwiGLU, that Llama performs separately.
-- Megatron uses `torch.baddbmm` within self attention, versus Llama using `torch.matmul`.
-- Megatron uses a `sin`/`cos` implementation for rotary position embeddings, versus Llama using a `polar`/`complex` implementation.
-- Llama calls `torch.set_default_dtype(torch.float16)` during initialization, which Megatron does not.
-
-### Big Bench
-
-Score type: multiple choice grade.
-
-| bigbench / standard | 7b | 13b | 70b |
-| -- | -- | -- | -- |
-| date_understanding | 0.29% | 0.13% | 0.12% |
-| general_knowledge | 0.00% | 0.00% | 0.00% |
-| human_organs_senses | 0.00% | 0.00% | 0.00% |
-| intent_recognition | 0.00% | 0.11% | 0.00% |
-| riddle_sense | 0.00% | 0.00% | 0.00% |
-| similarities_abstraction | 0.00% | 0.58% | 0.00% |
-| simple_arithmetic_json_multiple_choice | 0.00% | 0.00% | 0.00% |
-| undo_permutation | 0.19% | 0.19% | 0.18% |
-
-### Multilingual
-
-Score type: multiple choice grade.
-
-| multilingual / xcopa | 7b  | 13b  | 70b |
-| -- | -- | -- | -- |
-| en-template-mGPT-remove-punctuation | 0.08% | 0.00% | 0.00% |
-| et-template-mGPT-remove-punctuation | 0.00% | 0.13% | 0.25% |
-| ht-template-mGPT-remove-punctuation | 0.26% | 0.13% | 0.26% |
-| id-template-mGPT-remove-punctuation | 0.11% | 0.00% | 0.19% |
-| it-template-mGPT-remove-punctuation | 0.00% | 0.10% | 0.09% |
-| qu-template-mGPT-remove-punctuation | 0.00% | 0.00% | 0.27% |
-| sw-template-mGPT-remove-punctuation | 0.14% | 0.13% | 0.13% |
-| th-template-mGPT-remove-punctuation | 0.25% | 0.13% | 0.13% |
-| tr-template-mGPT-remove-punctuation | 0.26% | 0.00% | 0.34% |
-| vi-template-mGPT-remove-punctuation | 0.00% | 0.11% | 0.00% |
-| zh-template-mGPT-remove-punctuation | 0.00% | 0.10% | 0.09% |
-
-### LM Evaluation Harness
-
-Score type: multiple choice grade.
-
-| lm-eval | 7b  | 13b  | 70b |
-| -- | -- | -- | -- |
-| boolq | 0.04% | 0.04% | 0.07% |
-| hellaswag | 0.02% | 0.03% | 0.03% |
-| piqa | 0.00% | 0.00% | 0.07% |
-| winogrande | 0.00% | 0.11% | 0.20% |
-
-### MMLU
-
-Score type: multiple choice grade.
-
-Note: the number in brackets is the number of sub-tasks for each supercategory.
-
-| mmlu | 7b  | 13b  | 70b |
-| -- | -- | -- | -- |
-| stem [18]  | 0.79% | 0.05% | 0.01% |
-| humanities [13]  | 0.19% | 0.01% | 0.02% |
-| other (business, health, misc.) [14]  | 0.08% | 0.06% | 0.12% |
-| social sciences [12]  | 0.37% | 0.21% | 0.01% |
diff --git a/docs/llama_mistral.md b/docs/llama_mistral.md
new file mode 100644
index 0000000000..5dd61866e8
--- /dev/null
+++ b/docs/llama_mistral.md
@@ -0,0 +1,444 @@
+# Llama, Mistral and other Llama-like model support in Megatron-LM
+
+NOTE: To simplify the code, we now only support converting Llama-3.x and Mistral checkpoints downloaded from Huggingface.
+
+The [Llama-2](https://ai.meta.com/llama/) and [Llama-3.x](https://llama.meta.com/) family of models are an open-source set of pretrained & finetuned (for chat) models that have achieved strong results across a wide set of benchmarks. At their times of release, both Llama-2 and Llama-3 models achieved among the best results for open-source models, and were competitive with leading closed-source models (see https://arxiv.org/pdf/2307.09288.pdf and https://ai.meta.com/blog/meta-llama-3/).
+
+Similarly, [Mistral-7b](https://mistral.ai/news/announcing-mistral-7b/) is an open-source model with pretrained and finetuned (for chat) variants that achieve strong benchmark results.
+
+Architecturally, Llama-2, Llama-3, and Mistral-7b are very similar. As such, Megatron can support loading checkpoints from all three for inference and finetuning. Converting the checkpoints and loading them is slightly different for each model and is detailed below.
+
+# Contents
+
+- [Llama, Mistral and other Llama-like model support in Megatron-LM](#llama-mistral-and-other-llama-like-model-support-in-megatron-lm)
+- [Contents](#contents)
+- [Llama-2](#llama-2)
+  - [Download Meta or Huggingface checkpoints](#download-meta-or-huggingface-checkpoints)
+  - [Convert checkpoint format](#convert-checkpoint-format)
+    - [Meta format](#meta-format)
+    - [Huggingface format](#huggingface-format)
+  - [Launch model](#launch-model)
+    - [Launch Megatron](#launch-megatron)
+    - [Launch Meta](#launch-meta)
+    - [Launch Huggingface](#launch-huggingface)
+  - [Benchmark results](#benchmark-results)
+    - [Big Bench](#big-bench)
+    - [Multilingual](#multilingual)
+    - [LM Evaluation Harness](#lm-evaluation-harness)
+    - [MMLU](#mmlu)
+- [Llama-3.x](#llama-3x)
+  - [Download Huggingface checkpoints](#download-huggingface-checkpoints)
+  - [Convert checkpoint format](#convert-checkpoint-format-1)
+    - [Huggingface format](#huggingface-format-1)
+  - [(Optional) Validate checkpoints](#optional-validate-checkpoints)
+  - [Launch model](#launch-model-1)
+- [Mistral-7b](#mistral-7b)
+  - [Download Huggingface checkpoints](#download-huggingface-checkpoints-1)
+  - [Convert checkpoint format](#convert-checkpoint-format-2)
+  - [(Optional) Validate checkpoints](#optional-validate-checkpoints-1)
+  - [Launch model](#launch-model-2)
+- [Other Llama-like model support](#other-llama-like-model-support)
+- [Known numerical differences](#known-numerical-differences)
+- [Using legacy model format](#using-legacy-model-format)
+
+
+# Llama-2
+
+Llama-2 checkpoints can be loaded into Megatron for inference and for finetuning. Loading these checkpoints consists of three steps:
+
+1. Get access to download the checkpoints.
+2. Convert the checkpoints from Meta/Huggingface format to Megatron format.
+3. Setup arguments for launching the model.
+
+The following sections detail these steps. The final section lists benchmark result comparisons between: 1) Llama-2 inference code running the Meta-format checkpoints, and 2) Megatron inference code running the converted checkpoints.
+
+## Download Meta or Huggingface checkpoints
+
+Users must first apply for access to download the Llama-2 checkpoints either directly from [Meta](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) or through [Huggingface](https://huggingface.co/docs/transformers/main/model_doc/llama2) (HF). The checkpoints are available in two formats, Meta's native format (available from both the Meta and HF links), and HF's format (available only from HF). Either format can be converted to Megatron, as detailed next.
+
+## Convert checkpoint format
+
+We recommend passing `--dtype bf16` for training or finetuning. Inference can be done in bfloat16 or float16.
+
+### Meta format
+
+The Meta format checkpoints are converted to HF format as an intermediate step before converting to Megatron format. The `transformers` package is required, and must have version >=4.31.0 (e.g., `pip install transformers>=4.31.0`). (**Note**: we have specifically tested with versions `4.31.0` and `4.32.0`; your experience may vary with newer versions.) Assuming the downloaded checkpoints are in `$CHECKPOINT_DIR` (with separate sub-directories for 7B, 13B, 70B, etc.), the following example command can be used to convert from Llama-2 format to HF format in bfloat16:
+
+```
+python tools/checkpoint/convert.py \
+>   --model-type GPT \
+>   --loader llama_mistral \
+>   --load-dir ${META_FORMAT_DIR} \
+>   --model-size ${MODEL_SIZE} \
+>   --checkpoint-type meta \
+>   --tokenizer-model ${TOKENIZER_MODEL} \
+>   --saver core \
+>   --save-dir ${MEGATRON_FORMAT_DIR} \
+>   --target-tensor-parallel-size ${TP} \
+>   --target-pipeline-parallel-size ${PP} \
+>   --bf16
+```
+
+Valid values for `--model-size` are `llama2-7B`, `llama2-13B`, and `llama2-70B` (for pretrained-only models), and `llama2-7Bf`, `llama2-13Bf`, and `llama2-70Bf` (for chat-finetuned models).
+
+### Huggingface format
+
+The HF checkpoints can be converted to Megatron format by using Megatron's own Llama-2 checkpoint converter for HF format (see script `tools/checkpoint/loader_llama_mistral.py`). One important argument that must be set correctly is the tensor parallel size (`TP`) for each model. The following table shows these values:
+
+| Model size | Tensor parallel size (`TP`) |
+| ---------- | --------------------------- |
+|  7B        | 1                           |
+| 13B        | 2                           |
+| 70B        | 8                           |
+
+Using these values for `TP`, along with the path to the Llama-2 tokenizer model (automatically downloaded with original checkpoint download; see `${TOKENIZER_MODEL}` below), run the following command from the root of your Megatron source code to convert from HF format to Megatron format:
+
+```
+python tools/checkpoint/convert.py \
+>   --model-type GPT \
+>   --loader llama_mistral \
+>   --load-dir ${HF_FORMAT_DIR} \
+>   --model-size ${MODEL_SIZE} \
+>   --checkpoint-type hf \
+>   --tokenizer-model ${TOKENIZER_MODEL} \
+>   --saver core \
+>   --save-dir ${MEGATRON_FORMAT_DIR} \
+>   --target-tensor-parallel-size ${TP} \
+>   --target-pipeline-parallel-size ${PP} \
+>   --bf16
+```
+
+After this conversion, we are ready to load the checkpoints into a Megatron GPT model.
+
+## Launch model
+
+### Launch Megatron
+
+If loading for either inference or finetuning, use the following arguments:
+
+```
+--tensor-model-parallel-size ${TP} \
+--pipeline-model-parallel-size 1 \
+--seq-length 4096 \
+--max-position-embeddings 4096 \
+--tokenizer-type Llama2Tokenizer \
+--tokenizer-model ${TOKENIZER_MODEL} \
+--load ${CHECKPOINT_DIR} \
+--exit-on-missing-checkpoint \
+--use-checkpoint-args \
+--no-load-optim \
+--no-load-rng \
+--untie-embeddings-and-output-weights \
+--use-rotary-position-embeddings \
+--normalization RMSNorm \
+--no-position-embedding \
+--no-masked-softmax-fusion \
+--attention-softmax-in-fp32
+```
+
+**Note:** If you converted to the legacy model format (i.e., `--saver legacy`), please see [here](#using-legacy-model-format).
+
+### Launch Meta
+
+Meta checkpoints can be launched with: https://github.com/facebookresearch/llama
+
+### Launch Huggingface
+
+Huggingface checkpoints can be launched with: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+
+## Benchmark results
+
+The tables below list the benchmark comparisons between native Llama-2 (using Meta's checkpoint and Meta's inference code) and Megatron (using a converted HF checkpoint and Megatron's inference code).
+
+The values are the percent error between Megatron and Llama-2, calculated using the formula: `|<llama score> - <megatron score>| / <llama score>`, where the type of score is detailed before each table. Across all tests (80 total per model size), the mean error is 0.15%. The small difference in benchmark scores between the two models is due to minor arithmetic differences in implementation that alter the numerics slightly. Some of the factors that influence this difference include:
+
+- Megatron performs batch matrix multiplications in a couple places, such as within self attention and in SwiGLU, that Llama performs separately.
+- Megatron uses `torch.baddbmm` within self attention, versus Llama using `torch.matmul`.
+- Megatron uses a `sin`/`cos` implementation for rotary position embeddings, versus Llama using a `polar`/`complex` implementation.
+- Llama calls `torch.set_default_dtype(torch.float16)` during initialization, which Megatron does not.
+
+### Big Bench
+
+Score type: multiple choice grade.
+
+| bigbench / standard | 7b | 13b | 70b |
+| -- | -- | -- | -- |
+| date_understanding | 0.29% | 0.13% | 0.12% |
+| general_knowledge | 0.00% | 0.00% | 0.00% |
+| human_organs_senses | 0.00% | 0.00% | 0.00% |
+| intent_recognition | 0.00% | 0.11% | 0.00% |
+| riddle_sense | 0.00% | 0.00% | 0.00% |
+| similarities_abstraction | 0.00% | 0.58% | 0.00% |
+| simple_arithmetic_json_multiple_choice | 0.00% | 0.00% | 0.00% |
+| undo_permutation | 0.19% | 0.19% | 0.18% |
+
+### Multilingual
+
+Score type: multiple choice grade.
+
+| multilingual / xcopa | 7b  | 13b  | 70b |
+| -- | -- | -- | -- |
+| en-template-mGPT-remove-punctuation | 0.08% | 0.00% | 0.00% |
+| et-template-mGPT-remove-punctuation | 0.00% | 0.13% | 0.25% |
+| ht-template-mGPT-remove-punctuation | 0.26% | 0.13% | 0.26% |
+| id-template-mGPT-remove-punctuation | 0.11% | 0.00% | 0.19% |
+| it-template-mGPT-remove-punctuation | 0.00% | 0.10% | 0.09% |
+| qu-template-mGPT-remove-punctuation | 0.00% | 0.00% | 0.27% |
+| sw-template-mGPT-remove-punctuation | 0.14% | 0.13% | 0.13% |
+| th-template-mGPT-remove-punctuation | 0.25% | 0.13% | 0.13% |
+| tr-template-mGPT-remove-punctuation | 0.26% | 0.00% | 0.34% |
+| vi-template-mGPT-remove-punctuation | 0.00% | 0.11% | 0.00% |
+| zh-template-mGPT-remove-punctuation | 0.00% | 0.10% | 0.09% |
+
+### LM Evaluation Harness
+
+Score type: multiple choice grade.
+
+| lm-eval | 7b  | 13b  | 70b |
+| -- | -- | -- | -- |
+| boolq | 0.04% | 0.04% | 0.07% |
+| hellaswag | 0.02% | 0.03% | 0.03% |
+| piqa | 0.00% | 0.00% | 0.07% |
+| winogrande | 0.00% | 0.11% | 0.20% |
+
+### MMLU
+
+Score type: multiple choice grade.
+
+Note: the number in brackets is the number of sub-tasks for each supercategory.
+
+| mmlu | 7b  | 13b  | 70b |
+| -- | -- | -- | -- |
+| stem [18]  | 0.79% | 0.05% | 0.01% |
+| humanities [13]  | 0.19% | 0.01% | 0.02% |
+| other (business, health, misc.) [14]  | 0.08% | 0.06% | 0.12% |
+| social sciences [12]  | 0.37% | 0.21% | 0.01% |
+
+# Llama-3.x
+
+Llama-3.x checkpoints can be loaded into Megatron for inference and for finetuning. Loading these checkpoints consists of several steps:
+
+1. Get access to download the checkpoints (weights and tokenizer).
+2. Convert the checkpoints from Huggingface format to Megatron format.
+3. (Optional) Validate converted checkpoints
+4. Setup arguments for launching the model.
+
+The following sections detail these steps.
+
+## Download Huggingface checkpoints
+
+Users must first apply for access to download the Llama-3.x checkpoints from [Huggingface](https://huggingface.co/meta-llama).
+
+## Convert checkpoint format
+
+We recommend passing `--dtype bf16` for training or finetuning. Inference can be done in bfloat16 or float16.
+
+### Huggingface format
+
+The HF checkpoints can be converted to Megatron format by using Megatron's own Llama-3.x checkpoint converter for HF format (see script `tools/checkpoint/loader_llama_mistral.py`). One important argument that must be set correctly is the tensor parallel size (`TP`) for each model. The following table shows these values:
+
+| Model size | Tensor parallel size (`TP`) |
+| ---------- | --------------------------- |
+|  1B        | 1                           |
+|  3B        | 1                           |
+|  8B        | 1                           |
+| 70B        | 8                           |
+
+Using these values for `TP`, along with the path to the Llama-3.x tokenizer model (automatically downloaded with original checkpoint download; see `${TOKENIZER_MODEL}` below), run the following command from the root of your Megatron source code to convert from HF format to Megatron format:
+
+```
+$>: python tools/checkpoint/convert.py \
+ >    --bf16 \
+ >    --model-type GPT \
+ >    --loader llama_mistral \
+ >    --saver core \
+ >    --target-tensor-parallel-size ${TP} \
+ >    --checkpoint-type hf \
+ >    --load-dir ${HF_FORMAT_DIR} \
+ >    --save-dir ${MEGATRON_FORMAT_DIR} \
+ >    --tokenizer-model ${TOKENIZER_MODEL} \
+ >    --model-size llama3
+```
+
+After this conversion, we are ready to load the checkpoints into a Megatron GPT model.
+
+## (Optional) Validate checkpoints
+
+A Megatron-LM text generation server for Llama3 can be launched using the script `examples/inference/llama_mistral/run_text_generation_llama3.sh`. For Llama3.1, please use `examples/inference/llama_mistral/run_text_generation_llama3.1.sh`.
+
+Once running, query the server with `curl 'http://<server_host>:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8'  -d '{"prompts":["<prompt>"], "tokens_to_generate":100, "top_k":1}'`.
+
+A reference generation for comparison can be obtained from the Huggingface transformers library by running `python examples/inference/llama_mistral/huggingface_reference.py --model_path <HF_MODEL_DIR> --prompt <PROMPT>`.
+
+## Launch model
+
+If loading for either inference or finetuning, use the following arguments for Llama 3.0:
+
+```
+--tensor-model-parallel-size ${TP} \
+--pipeline-model-parallel-size 1 \
+--seq-length 8192 \
+--max-position-embeddings 8192 \
+--tokenizer-type HuggingFaceTokenizer \
+--tokenizer-model ${TOKENIZER_MODEL} \
+--load ${CHECKPOINT_DIR} \
+--exit-on-missing-checkpoint \
+--use-checkpoint-args \
+--no-load-optim \
+--no-load-rng \
+--untie-embeddings-and-output-weights \
+--normalization RMSNorm \
+--position-embedding-type rope \
+--no-masked-softmax-fusion \
+--attention-softmax-in-fp32 \
+--disable-bias-linear \
+--transformer-impl transformer_engine \
+--group-query-attention 8 \
+--attention-dropout 0.0 \
+--hidden-dropout 0.0 \
+--rotary-base 500000 \
+--rotary-percent 1.0 \
+--ffn-hidden-size 14336 \
+--num-attention-heads 32 \
+--swiglu \
+--bf16 \
+```
+
+For Llama3.1 please use the following arguments:
+
+```
+--tensor-model-parallel-size ${TP} \
+--pipeline-model-parallel-size 1 \
+--seq-length 8192 \
+--max-position-embeddings 131072 \
+--tokenizer-type HuggingFaceTokenizer \
+--tokenizer-model ${TOKENIZER_MODEL} \
+--load ${CHECKPOINT_DIR} \
+--exit-on-missing-checkpoint \
+--use-checkpoint-args \
+--no-load-optim \
+--no-load-rng \
+--untie-embeddings-and-output-weights \
+--normalization RMSNorm \
+--position-embedding-type rope \
+--no-masked-softmax-fusion \
+--attention-softmax-in-fp32 \
+--disable-bias-linear \
+--transformer-impl transformer_engine \
+--group-query-attention 8 \
+--attention-dropout 0.0 \
+--hidden-dropout 0.0 \
+--rotary-base 500000 \
+--rotary-percent 1.0 \
+--use-rope-scaling \
+--ffn-hidden-size 14336 \
+--num-attention-heads 32 \
+--swiglu \
+--bf16 \
+```
+
+**Note:** If you converted to the legacy model format (i.e., `--saver legacy`), please see [here](#using-legacy-model-format).
+
+# Mistral-7b
+
+Megatron currently supports loading the v0.3 release of Mistral-7b (which does not use sliding window attention and offers a larger 32768-token vocabulary) for inference and finetuning. Loading these checkpoints consists of several steps:
+
+1. Get access to download the checkpoints (weights and tokenizer).
+2. Convert the checkpoints from HuggingFace format to Megatron format.
+3. (Optional) Validate converted checkpoints
+4. Setup arguments for launching the model.
+
+The following sections detail these steps.
+
+## Download Huggingface checkpoints
+
+Users must first apply for access to download the Mistral-7b checkpoints through [Huggingface](https://huggingface.co/mistralai/Mistral-7B-v0.3) (HF).
+
+## Convert checkpoint format
+
+The HF checkpoints can be converted to Megatron format by using Megatron's own Mistral checkpoint converter for HF format (see script `tools/checkpoint/loader_llama_mistral.py`).
+
+Using the path to the Mistral tokenizer model (downloaded alongside the HF checkpoint), run the following command from the root of your Megatron source code to convert from HF format to the Megatron core format:
+
+```
+$>: python tools/checkpoint/convert.py \
+ >    --bf16 \
+ >    --model-type GPT \
+ >    --loader llama_mistral \
+ >    --saver core \
+ >    --target-tensor-parallel-size ${TP} \
+ >    --checkpoint-type hf \
+ >    --load-dir ${HF_FORMAT_DIR} \
+ >    --save-dir ${MEGATRON_FORMAT_DIR} \
+ >    --tokenizer-model ${TOKENIZER_MODEL} \
+ >    --model-size mistral
+```
+
+After this conversion, we are ready to load the checkpoints into a Megatron core GPT model.
+
+## (Optional) Validate checkpoints
+
+A Megatron-LM text generation server for Mistral-7B can be launched using the script `examples/inference/llama_mistral/run_text_generation_mistral.sh`.
+
+Once running, query the server with `curl 'http://<server_host>:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8'  -d '{"prompts":["<prompt>"], "tokens_to_generate":100, "top_k":1}'`.
+
+A reference generation for comparison can be obtained from the Huggingface transformers library by running `python examples/inference/llama_mistral/huggingface_reference.py --model_path <HF_MODEL_DIR> --prompt <PROMPT>`.
+
+## Launch model
+
+If loading for either inference or finetuning, use the following arguments:
+
+```
+--tensor-model-parallel-size ${TP} \
+--pipeline-model-parallel-size 1 \
+--seq-length 4096 \
+--max-position-embeddings 4096 \
+--tokenizer-type HuggingFaceTokenizer \
+--tokenizer-model ${TOKENIZER_MODEL} \
+--load ${CHECKPOINT_DIR} \
+--exit-on-missing-checkpoint \
+--use-checkpoint-args \
+--no-load-optim \
+--no-load-rng \
+--untie-embeddings-and-output-weights \
+--normalization RMSNorm \
+--position-embedding-type rope \
+--no-masked-softmax-fusion \
+--attention-softmax-in-fp32 \
+--apply-layernorm-1p \
+--transformer-impl transformer_engine \
+--group-query-attention 8 \
+--disable-bias-linear \
+--rotary-base 1000000 \
+--rotary-percent 1.0 \
+--swiglu \
+--ffn-hidden-size 14336 \
+--num-attention-heads 32
+```
+
+**Note:** If you converted to the legacy model format (i.e., `--saver legacy`), please see [here](#using-legacy-model-format).
+
+# Other Llama-like model support
+
+*Note: Experimental*
+
+Many models such as Yi-34B and Qwen2.x use the Llama architecture and may be converted from HuggingFace to Megatron using the commands in [Llama-3.x](#llama-3x).
+
+# Known numerical differences
+
+The Megatron and Huggingface implementations of the Llama-3.x and Mistral models are not expected to produce numerically identical results. There are multiple points where small numerical differences are expected. This is a non-exhaustive list:
+
+1. TransformerEngine (TE) uses the model's `params_dtype` inside RMSNorm, whereas the Huggingface implementation uses fp32. See https://github.com/NVIDIA/TransformerEngine/issues/1132 for details.
+2. Huggingface `transformers` implements the q, k and v projections in self-attention as separate GEMMs whereas Megatron core combines them into a single GEMM for efficiency. This leads to small numerical differences.
+
+# Using legacy model format
+
+All of the checkpoint conversion examples in this document use the saver format `--saver core`, signifying that the newer (and recommended) Megatron GPT model class will be used, i.e.:
+
+- old class: `megatron.legacy.model.gpt_model.GPTModel`
+- new class: `megatron.core.models.gpt.gpt_model.GPTModel`
+
+Using this new format is the recommended approach. However, if your use case requires the older class (i.e., conversion with `--saver legacy`), then the following args must be added when launching training or finetuning (a combined launch sketch follows this list):
+
+- `--use-legacy-models`: use the older model class
+- `--ckpt-format torch`: use the `torch` checkpoint format, which is the only checkpoint format that is compatible with the legacy model format
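+
+For example, a finetuning launch for a legacy-format checkpoint adds these two flags to an otherwise unchanged command, as in this sketch (paths and elided arguments are placeholders):
+
+```
+python pretrain_gpt.py \
+    <your usual model, data, and parallelism arguments> \
+    --load ${LEGACY_FORMAT_DIR} \
+    --use-legacy-models \
+    --ckpt-format torch
+```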
diff --git a/docs/source/api-guide/context_parallel.rst b/docs/source/api-guide/context_parallel.rst
new file mode 100644
index 0000000000..c08defd210
--- /dev/null
+++ b/docs/source/api-guide/context_parallel.rst
@@ -0,0 +1,35 @@
+context\_parallel package
+=========================
+
+Context parallelism overview 
+----------------------------
+
+.. figure:: ../images/context_parallel/CP_overview.png
+   :alt: cp_overview
+   :align: center
+   
+   Figure 1: A transformer layer running with TP2CP2. Communications next to Attention are for CP, others are for TP. (AG/RS: all-gather in forward and reduce-scatter in backward, RS/AG: reduce-scatter in forward and all-gather in backward, /AG: no-op in forward and all-gather in backward).
+
+Context Parallelism ("CP") is a parallelization scheme on the dimension of sequence length. Unlike prior SP (sequence parallelism) which only splits the sequence of Dropout and LayerNorm activations, CP partitions the network inputs and all activations along sequence dimension. With CP, all modules except attention (e.g., Linear, LayerNorm, etc.) can work as usual without any changes, because they do not have inter-token operations. As for attention, the Q (query) of each token needs to compute with the KV (key and value) of all tokens in the same sequence. Hence, CP requires additional all-gather across GPUs to collect the full sequence of KV. Correspondingly, reduce-scatter should be applied to the activation gradients of KV in backward propagation. To reduce activation memory footprint, each GPU only stores the KV of a sequence chunk in forward and gathers KV again in backward. KV communication happens between a GPU and its counterparts in other TP groups. The all-gather and reduce-scatter are transformed to point-to-point communications in ring topology under the hood. Exchanging KV also can leverage MQA/GQA to reduce communication volumes, as they only have one or few attention heads for KV.
+
+For example, in Figure 1, assuming sequence length is 8K, each GPU processes 4K tokens. GPU0 and GPU2 compose a CP group, and they exchange KV with each other. The same happens between GPU1 and GPU3. CP is similar to Ring Attention but provides better performance by (1) leveraging the latest OSS and cuDNN flash attention kernels; (2) removing unnecessary computation resulting from lower-triangle causal masking and achieving optimal load balance among GPUs.
+
+Context parallelism benefits 
+----------------------------
+
+.. figure:: ../images/context_parallel/CP_results.png
+   :alt: cp_results
+   :align: center
+   
+   Figure 2: Speedup of 175B GPT with various TP+CP combinations vs. full recompute (i.e., TP8CP1).
+
+LLMs encounter OOM (out of memory) issues with long context (i.e., long sequence length) because the memory footprint of activations grows linearly with sequence length. Recomputing activations in the backward pass can avoid OOM but also introduces significant overhead (~30% with full recompute). Enlarging TP (tensor model parallelism) can also fix the OOM issue, but it potentially makes compute (e.g., Linear) too short to overlap communication latencies. To be clear, scaling out to more GPUs with bigger TP can hit this overlapping problem regardless of whether OOM happens.
+
+CP addresses these issues better. With CP, each GPU only computes on a part of the sequence, which reduces both computation and communication by a factor of CP, so overlapping them is not a concern. The activation memory footprint per GPU is also CP times smaller, eliminating the OOM issue. As Figure 2 shows, combining TP and CP can achieve optimal performance by eliminating recompute overhead and striking the best tradeoff between computation and communication.
+
+Enabling context parallelism
+----------------------------
+
+CP support has been added to GPT. All models that share the GPT code path, such as Llama, should also be able to benefit from CP. CP can work with TP (tensor model parallelism), PP (pipeline model parallelism), and DP (data parallelism), where the total number of GPUs equals TPxCPxPPxDP. CP also can work with different attention variants, including MHA/MQA/GQA, and uni-directional and bi-directional masking.
+
+CP is enabled by simply setting context_parallel_size=<CP_SIZE> in the command line. The default context_parallel_size is 1, which means CP is disabled. Running with CP requires Megatron-Core (>=0.5.0) and Transformer Engine (>=1.1).
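+
+As a sketch, a TP2CP2 run on 8 GPUs could be launched as follows (assuming the command-line spelling --context-parallel-size; the launcher and elided arguments are placeholders for your own configuration):
+
+.. code-block:: bash
+
+   torchrun --nproc_per_node=8 pretrain_gpt.py \
+       --tensor-model-parallel-size 2 \
+       --context-parallel-size 2 \
+       <remaining model, data, and parallelism arguments>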
diff --git a/docs/source/api-guide/datasets.rst b/docs/source/api-guide/datasets.rst
new file mode 100644
index 0000000000..247a3f07d3
--- /dev/null
+++ b/docs/source/api-guide/datasets.rst
@@ -0,0 +1,104 @@
+datasets package
+================
+
+.. mdinclude :: ../../../megatron/core/datasets/readme.md
+
+Submodules
+----------
+
+datasets.blended\_megatron\_dataset\_config module
+---------------------------------------------------
+
+.. automodule:: core.datasets.blended_megatron_dataset_config
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+datasets.blended\_megatron\_dataset\_builder module
+---------------------------------------------------
+
+.. automodule:: core.datasets.blended_megatron_dataset_builder
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+datasets.megatron\_tokenizer module
+-----------------------------------
+
+.. automodule:: core.datasets.megatron_tokenizer
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+datasets.indexed\_dataset module
+--------------------------------
+
+.. automodule:: core.datasets.indexed_dataset
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+datasets.megatron\_dataset module
+---------------------------------
+
+.. automodule:: core.datasets.megatron_dataset
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+datasets.gpt\_dataset module
+----------------------------
+
+.. automodule:: core.datasets.gpt_dataset
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+datasets.masked\_dataset module
+-------------------------------
+
+.. automodule:: core.datasets.masked_dataset
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+datasets.bert\_dataset module
+-----------------------------
+
+.. automodule:: core.datasets.bert_dataset
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+datasets.t5\_dataset module
+---------------------------
+
+.. automodule:: core.datasets.t5_dataset
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+datasets.blended\_dataset module
+----------------------------------
+
+.. automodule:: core.datasets.blended_dataset
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+datasets.utils module
+---------------------
+
+.. automodule:: core.datasets.utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: core.datasets
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
diff --git a/docs/source/api-guide/dist_checkpointing.rst b/docs/source/api-guide/dist_checkpointing.rst
new file mode 100644
index 0000000000..7e384a08a3
--- /dev/null
+++ b/docs/source/api-guide/dist_checkpointing.rst
@@ -0,0 +1,79 @@
+dist\_checkpointing package
+===========================
+
+A library for saving and loading distributed checkpoints.
+A "distributed checkpoint" can have various underlying formats (the current default format is based on Zarr),
+but it has a distinctive property: a checkpoint saved in one parallel configuration (tensor/pipeline/data parallelism)
+can be loaded in a different parallel configuration.
+
+Using the library requires defining sharded state_dict dictionaries with functions from the *mapping* and *optimizer* modules.
+Those state dicts can be saved or loaded with the *serialization* module using strategies from the *strategies* module.
+
+
+Subpackages
+-----------
+
+.. toctree::
+   :maxdepth: 4
+
+   dist_checkpointing.strategies
+
+Submodules
+----------
+
+dist\_checkpointing.serialization module
+----------------------------------------
+
+.. automodule:: core.dist_checkpointing.serialization
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+dist\_checkpointing.mapping module
+----------------------------------
+
+.. automodule:: core.dist_checkpointing.mapping
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+dist\_checkpointing.optimizer module
+------------------------------------
+
+.. automodule:: core.dist_checkpointing.optimizer
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+dist\_checkpointing.core module
+-------------------------------
+
+.. automodule:: core.dist_checkpointing.core
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+dist\_checkpointing.dict\_utils module
+--------------------------------------
+
+.. automodule:: core.dist_checkpointing.dict_utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+
+dist\_checkpointing.utils module
+--------------------------------
+
+.. automodule:: core.dist_checkpointing.utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: core.dist_checkpointing
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/api-guide/dist_checkpointing.strategies.rst b/docs/source/api-guide/dist_checkpointing.strategies.rst
new file mode 100644
index 0000000000..41e674c761
--- /dev/null
+++ b/docs/source/api-guide/dist_checkpointing.strategies.rst
@@ -0,0 +1,50 @@
+dist\_checkpointing.strategies package
+======================================
+
+Package defining different checkpoint formats (backends) and saving/loading algorithms (strategies).
+
+Strategies can be used to implement new checkpoint formats or new (better suited to a given use case) ways of saving and loading existing formats.
+Strategies are passed to the `dist_checkpointing.load` and `dist_checkpointing.save` functions and control the actual saving/loading procedure.
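+
+For example, a load strategy can be selected explicitly instead of relying on the default (a hedged sketch; the exact class name and constructor defaults are assumptions):
+
+.. code-block:: python
+
+   from megatron.core import dist_checkpointing
+   from megatron.core.dist_checkpointing.strategies.tensorstore import (
+       TensorStoreLoadShardedStrategy,
+   )
+
+   # sharded_state_dict is built with the mapping/optimizer helpers as usual.
+   loaded = dist_checkpointing.load(
+       sharded_state_dict,
+       '/path/to/checkpoint',
+       sharded_strategy=TensorStoreLoadShardedStrategy(),
+   )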
+
+Submodules
+----------
+
+dist\_checkpointing.strategies.base module
+------------------------------------------
+
+.. automodule:: core.dist_checkpointing.strategies.base
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+dist\_checkpointing.strategies.tensorstore module
+-------------------------------------------------
+
+.. automodule:: core.dist_checkpointing.strategies.tensorstore
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+dist\_checkpointing.strategies.two\_stage module
+------------------------------------------------
+
+.. automodule:: core.dist_checkpointing.strategies.two_stage
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+dist\_checkpointing.strategies.zarr module
+------------------------------------------
+
+.. automodule:: core.dist_checkpointing.strategies.zarr
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: core.dist_checkpointing.strategies
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/api-guide/dist_optimizer.md b/docs/source/api-guide/dist_optimizer.md
new file mode 100644
index 0000000000..34f42d5343
--- /dev/null
+++ b/docs/source/api-guide/dist_optimizer.md
@@ -0,0 +1,40 @@
+# Distributed Optimizer
+
+The motivation for the distributed optimizer is to save memory by distributing the optimizer state evenly across data parallel ranks (https://arxiv.org/abs/1910.02054), versus the naive method of replicating the optimizer state across data parallel ranks.
+
+Theoretical memory savings vary depending on the combination of the datatype of the model's parameters (`param_dtype`) and main gradients accumulated across data-parallel replicas (`grad_dtype`). We always use `fp32` main parameters for optimizer steps. In the current implementation, the theoretical number of bytes per parameter is (where d is the data parallel size):
+
+|        | Non-distributed optim | Distributed optim |
+| ------ | ------ | ------ |
+| `fp16` parameters, `fp16` gradients | 20 | 4 + 16/d |
+| `bf16` parameters, `fp32` gradients    | 18 | 6 + 12/d |
+| `fp32` parameters, `fp32` gradients       | 16 | 8 + 8/d  |
+
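+The sketch below reproduces the table's accounting: parameters and gradients stay replicated on every data-parallel rank, while the `fp32` main parameters, the `fp32` main gradients (when gradients are kept in lower precision), and the two Adam moments are sharded across the `d` ranks.
+
+```python
+def theoretical_bytes_per_param(param_bytes: int, grad_bytes: int, d: int):
+    """Reproduce the table above (a sketch; Adam moments are 2 x 4 bytes in fp32)."""
+    replicated = param_bytes + grad_bytes      # model params + grads on every DP rank
+    sharded = 8                                # fp32 exp_avg + exp_avg_sq
+    sharded += 4 if param_bytes != 4 else 0    # fp32 main params (low-precision params only)
+    sharded += 4 if grad_bytes != 4 else 0     # fp32 main grads (low-precision grads only)
+    return replicated + sharded, replicated + sharded / d
+
+# bf16 params, fp32 grads, d = 8  ->  (18, 7.5)
+print(theoretical_bytes_per_param(2, 4, 8))
+```
+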
+Our implementation of the distributed optimizer uses contiguous buffers for parameters and main gradients; model gradients are copied over to the main gradients as soon as they are fully computed.
+
+The figures below illustrate the distributed optimizer's sharding scheme, and the key steps of the distributed optimizer's parameter update:
+
+## Data flow
+
+![Data flow](../images/distrib_optimizer/data_flow.png)
+
+## Sharding scheme
+
+![Sharding scheme](../images/distrib_optimizer/sharding_scheme.png)
+
+## Key steps
+
+_(Note: the illustrations above assume `bf16` model weights, `bf16` model gradients computed by the backward pass, and `fp32` main gradients that are also used for optimizer steps; we always use `fp32` main weights for optimizer steps. A conceptual sketch of these steps follows the list below.)_
+
+- Backward pass finishes (gradient buffer holds 16 `fp32` gradient elements).
+- Call reduce-scatter on each DP rank.
+- Each DP rank now has 4 elements within the gradient buffer that are fully reduced (remaining 12 elements are garbage).
+  - DP rank 0 has gradient values for elements [0:4].
+  - DP rank 1 has gradient values for elements [4:8].
+  - DP rank 2 has gradient values for elements [8:12].
+  - DP rank 3 has gradient values for elements [12:16].
+- Optimizer.step().
+- Each DP rank copies its 4 `fp32` main parameter elements into the corresponding `bf16` parameter buffer (each element is cast from `fp32` to `bf16`).
+- Call all-gather on each DP rank.
+- The parameter buffer now contains all 16, fully updated, `bf16` model parameter elements. Parameters in PyTorch modules already point to the appropriate locations in this parameter buffer, and thus forward passes are ready to run after the all-gather completes.
+- At this point, the gradient buffer is also ready to be zero'd for the next iteration.
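+
+The following is a conceptual sketch of this cycle using raw `torch.distributed` collectives. It is not the actual Megatron-LM implementation; the buffer and optimizer wiring is deliberately simplified.
+
+```python
+import torch
+import torch.distributed as dist
+
+def distributed_optimizer_step(grad_buffer, param_buffer, main_params, optimizer):
+    """grad_buffer: fp32 grads; param_buffer: bf16 params; main_params: this rank's fp32 shard."""
+    d, rank = dist.get_world_size(), dist.get_rank()
+    shard = grad_buffer.numel() // d
+    local = slice(rank * shard, (rank + 1) * shard)
+
+    # 1. Reduce-scatter: only our shard of the gradients ends up fully reduced.
+    local_grads = torch.empty(shard, dtype=grad_buffer.dtype, device=grad_buffer.device)
+    dist.reduce_scatter_tensor(local_grads, grad_buffer)
+
+    # 2. Optimizer step on the fp32 main-parameter shard owned by this rank.
+    main_params.grad = local_grads
+    optimizer.step()
+
+    # 3. Cast the updated fp32 shard back into the bf16 parameter buffer.
+    param_buffer[local].copy_(main_params.detach().to(param_buffer.dtype))
+
+    # 4. All-gather so every rank sees the full, updated bf16 parameters.
+    dist.all_gather_into_tensor(param_buffer, param_buffer[local].contiguous())
+
+    grad_buffer.zero_()  # ready for the next iteration
+```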
diff --git a/docs/source/api-guide/distributed.rst b/docs/source/api-guide/distributed.rst
new file mode 100644
index 0000000000..737820331c
--- /dev/null
+++ b/docs/source/api-guide/distributed.rst
@@ -0,0 +1,53 @@
+distributed package
+===================
+
+This package contains various utilities to finalize model weight gradients
+on each rank before the optimizer step. This includes a distributed data
+parallelism wrapper to all-reduce or reduce-scatter the gradients across
+data-parallel replicas, and a `finalize_model_grads` method to
+synchronize gradients across different parallelism modes (e.g., 'tied'
+layers on different pipeline stages, or gradients for experts in a MoE on
+different ranks due to expert parallelism).
+
+Submodules
+----------
+
+distributed.distributed\_data\_parallel
+---------------------------------------
+
+Model wrapper for distributed data parallelism. Stores gradients in a
+contiguous buffer, and supports the option of overlapping communication
+(all-reduce or reduce-scatter) with backprop computation by breaking up
+the full model's gradients into smaller buckets and running all-reduce /
+reduce-scatter on each bucket asynchronously.
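+
+Below is a conceptual sketch of the overlap idea using plain PyTorch hooks (each
+parameter is treated as its own bucket for brevity, and PyTorch 2.1+ is assumed for
+`register_post_accumulate_grad_hook`); it is not this module's actual implementation.
+
+.. code-block:: python
+
+   import torch
+   import torch.distributed as dist
+
+   def attach_overlapped_allreduce(model: torch.nn.Module):
+       """Launch an async all-reduce per gradient as soon as it is produced,
+       overlapping communication with the rest of the backward pass."""
+       handles = []
+
+       def make_hook(param):
+           def hook(_):
+               handles.append(dist.all_reduce(param.grad, async_op=True))
+           return hook
+
+       for p in model.parameters():
+           if p.requires_grad:
+               p.register_post_accumulate_grad_hook(make_hook(p))
+
+       def finalize():  # call after backward(), before the optimizer step
+           for handle in handles:
+               handle.wait()
+           handles.clear()
+
+       return finalize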
+
+.. automodule:: core.distributed.distributed_data_parallel
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+distributed.finalize\_model\_grads
+----------------------------------
+
+Finalize model gradients for optimizer step across all used parallelism modes.
+Synchronizes the all-reduce / reduce-scatter of model gradients across DP replicas,
+all-reduces the layernorm gradients for sequence parallelism, embedding gradients
+across first and last pipeline stages (if not tied), and expert gradients for expert
+parallelism.
+
+.. automodule:: core.distributed.finalize_model_grads
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+
+Module contents
+---------------
+
+Contains functionality to synchronize gradients across different ranks before
+optimizer step.
+
+.. automodule:: core.distributed
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/api-guide/encoder_decoder_parallelism.rst b/docs/source/api-guide/encoder_decoder_parallelism.rst
new file mode 100644
index 0000000000..7cdff941de
--- /dev/null
+++ b/docs/source/api-guide/encoder_decoder_parallelism.rst
@@ -0,0 +1,54 @@
+encoder-decoder-parallelism package
+===================================
+
+MCore (as of 0.9) supports heterogeneous parallelism for encoder-decoder models.
+In particular, the user is now able to specify the amount of tensor and pipeline parallelism in the encoder and have it be
+distinct from that of the decoder.
+
+Submodules
+----------
+
+Encoder Pipeline Parallelism
+----------------------------
+
+Supported in: T5, LLaVa.
+
+The new argument for encoder parallelism is `--encoder-pipeline-model-parallel-size`. This argument is completely distinct
+from the usual argument that controls pipelining: `--pipeline-model-parallel-size`, which controls the amount of pipelining in the decoder
+in the context of encoder-decoder models.
+
+The total amount of pipelining in an encoder-decoder model is the sum of these two arguments. By default, the amount of
+encoder pipelining is 0, and the amount of decoder pipelining is 1, meaning that the encoder & decoder share a single pipeline rank.
+If `--pipeline-model-parallel-size` > 1, then the amount of encoder pipeline parallelism has to be specified and must be greater than 0.
+This is because we are not able to share pipeline ranks between the encoder and decoder anymore.
+
+Encoder Tensor Parallelism
+--------------------------
+
+Supported in: LLaVa.
+
+Since we expect encoders to be much smaller than decoders, we also give users the ability to set a different amount of tensor
+parallelism than the decoder. This is achieved with the argument `--encoder-tensor-model-parallel-size`. To use this option, you must
+be using encoder pipeline parallelism (i.e., `--encoder-pipeline-model-parallel-size` > 0).
+
+Unlike encoder pipeline parallelism, which is unrestricted by the amount of decoder pipeline parallelism, the encoder may only have
+less than or the same amount of tensor parallelism as the decoder. In short, within p2p_communication.py we have
+to send the activations of one encoder rank to several decoder ranks and, correspondingly, sum the gradients coming back from those
+(downstream) decoder ranks onto the encoder rank. We have not yet seen quantization-related degradation from summing these gradient tensors
+together; it could happen in very large models.
+
+
+Number of GPUs Required
+-----------------------
+
+The total number of GPUs required to train a model when these options are enabled is::
+
+    dp * etp * epp * cp + dp * tp * pp * cp
+
+where:
+
+* dp: amount of data parallelism (this is the same for the encoder & decoder)
+* [e]tp: amount of tensor parallelism
+* [e]pp: amount of pipeline parallelism
+* cp: amount of context parallelism (as with dp, this is the same for the encoder & decoder)
+
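+For example (hypothetical sizes), an encoder with `etp=1, epp=1` next to a decoder with `tp=4, pp=2`, sharing `dp=2` and `cp=1`:
+
+.. code-block:: python
+
+   dp, cp = 2, 1          # shared by encoder and decoder
+   etp, epp = 1, 1        # encoder tensor / pipeline parallel sizes
+   tp, pp = 4, 2          # decoder tensor / pipeline parallel sizes
+
+   total_gpus = dp * etp * epp * cp + dp * tp * pp * cp
+   print(total_gpus)      # 2*1*1*1 + 2*4*2*1 = 18
+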
+The default value of `--encoder-tensor-model-parallel-size` is 0; in practice, we then use the decoder's amount of tensor parallelism to construct the encoder.
diff --git a/docs/source/api-guide/fusions.rst b/docs/source/api-guide/fusions.rst
new file mode 100644
index 0000000000..22782ca84e
--- /dev/null
+++ b/docs/source/api-guide/fusions.rst
@@ -0,0 +1,65 @@
+fusions package
+===============
+
+This package provides commonly fused
+operations. Fusing operations improves compute efficiency by
+increasing the amount of work done each time a tensor is read from
+memory. To perform the fusion, modules in this package either rely on PyTorch
+functionality for just-in-time compilation
+(i.e. `torch.jit.script` in older PyTorch versions, or `torch.compile`
+in recent versions), or call into custom kernels in external libraries
+such as Apex or TransformerEngine.
+
+Submodules
+----------
+
+fusions.fused\_bias\_dropout module
+-----------------------------------
+
+This module uses PyTorch JIT to fuse the bias add and dropout operations. Since dropout is not used during inference, different functions are used when in train mode and when in inference mode.
+
+.. automodule:: core.fusions.fused_bias_dropout
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+fusions.fused\_bias\_gelu module
+--------------------------------
+
+This module uses PyTorch JIT to fuse the bias add and GeLU nonlinearity operations.
+
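+A minimal sketch of the idea (not the library's exact kernel): writing the bias add and the GeLU approximation element-wise and scripting the function lets the JIT fuse the whole expression into a single kernel.
+
+.. code-block:: python
+
+   import torch
+
+   @torch.jit.script
+   def bias_gelu(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+       x = bias + y
+       # tanh approximation of GeLU, written element-wise so the JIT can fuse it
+       return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))
+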
+.. automodule:: core.fusions.fused_bias_gelu
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+fusions.fused\_layer\_norm module
+---------------------------------
+
+This module provides a wrapper around various fused LayerNorm implementations in Apex.
+
+.. automodule:: core.fusions.fused_layer_norm
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+fusions.fused\_softmax module
+-----------------------------
+
+This module provides wrappers around variations of Softmax in Apex.
+
+.. automodule:: core.fusions.fused_softmax
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+fusions.fused\_cross\_entropy\_loss module
+------------------------------------------
+
+This module uses PyTorch JIT to fuse the cross entropy loss calculation and to batch communication calls.
+
+.. automodule:: core.fusions.fused_cross_entropy
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
diff --git a/docs/source/api-guide/index.rst b/docs/source/api-guide/index.rst
new file mode 100644
index 0000000000..dac785af04
--- /dev/null
+++ b/docs/source/api-guide/index.rst
@@ -0,0 +1,20 @@
+API Guide
+=========
+
+.. toctree::
+   :maxdepth: 4
+
+   models
+   tensor_parallel
+   context_parallel
+   pipeline_parallel
+   fusions
+   transformer
+   moe
+   dist_checkpointing
+   dist_optimizer
+   distributed
+   datasets
+   num_microbatches_calculator
+   optimizer_param_scheduler
+   encoder_decoder_parallelism
\ No newline at end of file
diff --git a/docs/source/api-guide/models.bert.rst b/docs/source/api-guide/models.bert.rst
new file mode 100644
index 0000000000..1b562ce72c
--- /dev/null
+++ b/docs/source/api-guide/models.bert.rst
@@ -0,0 +1,22 @@
+models.bert package
+===================
+A package for training BERT and BERT-like encoder-only models. It optionally comes with a binary head that can be used for classification tasks.
+
+Submodules
+----------
+
+models.bert.bert\_model module
+------------------------------
+
+.. automodule:: core.models.bert.bert_model
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: core.models.bert
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/api-guide/models.gpt.rst b/docs/source/api-guide/models.gpt.rst
new file mode 100644
index 0000000000..31c4da6a9c
--- /dev/null
+++ b/docs/source/api-guide/models.gpt.rst
@@ -0,0 +1,22 @@
+models.gpt package
+==================
+This is the implementation of the popular GPT model. It supports several features like model parallelism (tensor, pipeline, and data parallel), mixture of experts, FP8, distributed optimizer, etc. We are constantly adding new features, so be on the lookout or raise an issue if you want something added.
+
+Submodules
+----------
+
+models.gpt.gpt\_model module
+----------------------------
+
+.. automodule:: core.models.gpt.gpt_model
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: core.models.gpt
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/api-guide/models.rst b/docs/source/api-guide/models.rst
new file mode 100644
index 0000000000..12c40e4f35
--- /dev/null
+++ b/docs/source/api-guide/models.rst
@@ -0,0 +1,21 @@
+models package
+==============
+This package contains most of the popular LLMs. Currently we have support for GPT, BERT, T5 and Retro. This is an ever-growing list, so keep an eye out.
+
+Subpackages
+-----------
+
+.. toctree::
+   :maxdepth: 4
+
+   models.gpt
+   models.t5
+   models.bert
+
+Module contents
+---------------
+
+.. automodule:: core.models
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/api-guide/models.t5.rst b/docs/source/api-guide/models.t5.rst
new file mode 100644
index 0000000000..1cc3315682
--- /dev/null
+++ b/docs/source/api-guide/models.t5.rst
@@ -0,0 +1,21 @@
+models.t5 package
+=================
+
+Submodules
+----------
+
+models.t5.t5\_model module
+--------------------------
+
+.. automodule:: core.models.T5.t5_model
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: core.models.T5
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/api-guide/moe.rst b/docs/source/api-guide/moe.rst
new file mode 100644
index 0000000000..9afc01e080
--- /dev/null
+++ b/docs/source/api-guide/moe.rst
@@ -0,0 +1,4 @@
+Mixture of Experts package
+==========================
+
+.. mdinclude :: ../../../megatron/core/transformer/moe/README.md
diff --git a/docs/source/api-guide/num_microbatches_calculator.rst b/docs/source/api-guide/num_microbatches_calculator.rst
new file mode 100644
index 0000000000..4790b31749
--- /dev/null
+++ b/docs/source/api-guide/num_microbatches_calculator.rst
@@ -0,0 +1,12 @@
+Microbatches Calculator
+=======================
+This API is used to calculate the number of microbatches to use per step, given the global batch size, micro batch size, and data-parallel size.
+
+
+Module contents
+---------------
+
+.. automodule:: core.num_microbatches_calculator
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/api-guide/optimizer_param_scheduler.rst b/docs/source/api-guide/optimizer_param_scheduler.rst
new file mode 100644
index 0000000000..caf5d8abfb
--- /dev/null
+++ b/docs/source/api-guide/optimizer_param_scheduler.rst
@@ -0,0 +1,12 @@
+Optimizer Parameters Scheduler
+==============================
+This API is used to schedule the learning rate and weight decay for the optimizer over the course of training.
+
+
+Module contents
+---------------
+
+.. automodule:: core.optimizer_param_scheduler
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/api-guide/pipeline_parallel.rst b/docs/source/api-guide/pipeline_parallel.rst
new file mode 100644
index 0000000000..5c67079a70
--- /dev/null
+++ b/docs/source/api-guide/pipeline_parallel.rst
@@ -0,0 +1,47 @@
+pipeline\_parallel package
+==========================
+
+This package contains implementations for two different pipeline parallelism
+schedules (one without interleaving and one with interleaving, see `Efficient
+Large-Scale Language Model Training on GPU Clusters Using Megatron-LM <https://arxiv.org/abs/2104.04473>`_
+for details), and a default no-pipelining schedule. It also contains methods
+for the point-to-point communication that is needed between pipeline stages.
+
+Submodules
+----------
+
+pipeline\_parallel.p2p\_communication module
+--------------------------------------------
+
+Contains implementations for the various point-to-point communication needed
+(e.g., `recv_forward` and `recv_backward`) in the different pipeline parallelism
+schedules.
+
+.. automodule:: core.pipeline_parallel.p2p_communication
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+pipeline\_parallel.schedules module
+-----------------------------------
+
+Contains implementations for two pipeline parallelism schedules
+(`forward_backward_pipelining_with_interleaving` for pipeline parallelism with
+interleaving, `forward_backward_pipelining_without_interleaving` for pipeline
+parallelism without interleaving) and a default no-pipelining schedule
+(`forward_backward_no_pipelining`). `get_forward_backward_func` returns the right
+scheduling function to use based on the configuration being trained
+(e.g., if pipeline-parallel size is 1, use `forward_backward_no_pipelining`).
+
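+A hedged usage sketch (argument names may differ between releases; `forward_step_func`, `data_iterator`, and `model` are assumed to be defined elsewhere):
+
+.. code-block:: python
+
+   from megatron.core.pipeline_parallel import get_forward_backward_func
+
+   forward_backward_func = get_forward_backward_func()
+   losses_reduced = forward_backward_func(
+       forward_step_func=forward_step_func,  # user-provided forward + loss function
+       data_iterator=data_iterator,
+       model=model,
+       num_microbatches=8,
+       seq_length=2048,
+       micro_batch_size=1,
+       forward_only=False,
+   )
+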
+.. automodule:: core.pipeline_parallel.schedules
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: core.pipeline_parallel
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/api-guide/tensor_parallel.rst b/docs/source/api-guide/tensor_parallel.rst
new file mode 100644
index 0000000000..d8ae9dea22
--- /dev/null
+++ b/docs/source/api-guide/tensor_parallel.rst
@@ -0,0 +1,67 @@
+tensor\_parallel package
+========================
+
+This package contains an implementation for tensor parallelism in transformer
+models (see `Megatron-LM: Training Multi-Billion Parameter Language Models
+Using Model Parallelism <https://arxiv.org/abs/1909.08053>`_ and `Reducing
+Activation Recomputation in Large Transformer Models <https://arxiv.org/abs/2205.05198>`_
+for details).
+
+Submodules
+----------
+
+tensor\_parallel.cross\_entropy module
+--------------------------------------
+
+.. automodule:: core.tensor_parallel.cross_entropy
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+tensor\_parallel.data module
+----------------------------
+
+.. automodule:: core.tensor_parallel.data
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+tensor\_parallel.layers module
+------------------------------
+
+.. automodule:: core.tensor_parallel.layers
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+tensor\_parallel.mappings module
+--------------------------------
+
+.. automodule:: core.tensor_parallel.mappings
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+tensor\_parallel.random module
+------------------------------
+
+.. automodule:: core.tensor_parallel.random
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+tensor\_parallel.utils module
+-----------------------------
+
+.. automodule:: core.tensor_parallel.utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: core.tensor_parallel
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/api-guide/transformer.rst b/docs/source/api-guide/transformer.rst
new file mode 100644
index 0000000000..6e2e894d54
--- /dev/null
+++ b/docs/source/api-guide/transformer.rst
@@ -0,0 +1,136 @@
+transformer package
+===================
+
+The `transformer` package provides a customizable and configurable
+implementation of the transformer model architecture. Each component
+of a transformer stack, from entire layers down to individual linear
+layers, can be customized by swapping in different PyTorch modules
+using the "spec" parameters (see `here
+`_). The
+configuration of the transformer (hidden size, number of layers,
+number of attention heads, etc.) is provided via a `TransformerConfig`
+object.
+
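+For example, a minimal configuration might look like the following (a hedged sketch showing only a few common fields; the remaining options have defaults):
+
+.. code-block:: python
+
+   from megatron.core.transformer.transformer_config import TransformerConfig
+
+   config = TransformerConfig(
+       num_layers=2,
+       hidden_size=512,
+       num_attention_heads=8,
+       use_cpu_initialization=True,
+   )
+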
+Submodules
+----------
+
+transformer.attention module
+----------------------------
+
+This is the entire attention portion, either self or cross attention,
+of a transformer layer including the query, key, and value
+projections, a "core" attention calculation (e.g. dot product
+attention), and final output linear projection.
+
+.. automodule:: core.transformer.attention
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+transformer.dot\_product\_attention module
+------------------------------------------
+
+This is a PyTorch-only implementation of dot product attention. A more
+efficient implementation, such as those provided by FlashAttention or
+cuDNN's fused attention, is typically used when training speed is
+important.
+
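+For reference, generic scaled dot product attention in plain PyTorch looks like the following (a sketch, not this module's code):
+
+.. code-block:: python
+
+   import math
+   import torch
+
+   def scaled_dot_product_attention(q, k, v, mask=None):
+       # q, k, v: [batch, heads, seq, head_dim]
+       scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
+       if mask is not None:
+           scores = scores.masked_fill(mask, float('-inf'))
+       return torch.softmax(scores, dim=-1) @ v
+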
+.. automodule:: core.transformer.dot_product_attention
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+transformer.enums module
+------------------------
+
+.. automodule:: core.transformer.enums
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+transformer.identity\_op module
+-------------------------------
+
+This provides a pass-through module that can be used in specs to
+indicate that the operation should not be performed. For example, when
+using a LayerNorm fused with the subsequent linear layer, an IdentityOp can be
+passed in as the LayerNorm module to use.
+
+.. automodule:: core.transformer.identity_op
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+transformer.mlp module
+----------------------
+
+This is the entire MLP portion of the transformer layer with an input
+projection, non-linearity, and output projection.
+
+.. automodule:: core.transformer.mlp
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+transformer.module module
+-------------------------
+
+This provides a common base class for all modules used in the
+transformer that contains some common functionality.
+
+.. automodule:: core.transformer.module
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+transformer.transformer\_block module
+-------------------------------------
+
+A block, or stack, of several transformer layers. The layers can all
+be the same or each can be unique.
+
+.. automodule:: core.transformer.transformer_block
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+transformer.transformer\_config module
+--------------------------------------
+
+This contains all of the configuration options for the
+transformer. Using a dataclass reduces code bloat by keeping all
+arguments together in a dataclass instead of passing several arguments
+through multiple layers of function calls.
+
+.. automodule:: core.transformer.transformer_config
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+transformer.transformer\_layer module
+-------------------------------------
+
+A single standard transformer layer including attention and MLP blocks.
+
+.. automodule:: core.transformer.transformer_layer
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+transformer.utils module
+------------------------
+
+Various utilities used in the transformer implementation.
+
+.. automodule:: core.transformer.utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: core.transformer
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/images/context_parallel/CP_overview.png b/docs/source/images/context_parallel/CP_overview.png
new file mode 100644
index 0000000000..38c55b371a
Binary files /dev/null and b/docs/source/images/context_parallel/CP_overview.png differ
diff --git a/docs/source/images/context_parallel/CP_results.png b/docs/source/images/context_parallel/CP_results.png
new file mode 100644
index 0000000000..e0415ce86e
Binary files /dev/null and b/docs/source/images/context_parallel/CP_results.png differ
diff --git a/docs/source/images/distrib_optimizer/data_flow.png b/docs/source/images/distrib_optimizer/data_flow.png
new file mode 100644
index 0000000000..01f5cfb2e7
Binary files /dev/null and b/docs/source/images/distrib_optimizer/data_flow.png differ
diff --git a/docs/source/images/distrib_optimizer/sharding_scheme.png b/docs/source/images/distrib_optimizer/sharding_scheme.png
new file mode 100644
index 0000000000..e48dd95024
Binary files /dev/null and b/docs/source/images/distrib_optimizer/sharding_scheme.png differ
diff --git a/docs/source/images/moe/token_drop.png b/docs/source/images/moe/token_drop.png
new file mode 100644
index 0000000000..1c335ee7aa
Binary files /dev/null and b/docs/source/images/moe/token_drop.png differ
diff --git a/docs/source/index.rst b/docs/source/index.rst
new file mode 100644
index 0000000000..f2a89b8ac7
--- /dev/null
+++ b/docs/source/index.rst
@@ -0,0 +1,23 @@
+.. Lumache documentation master file, created by
+   sphinx-quickstart on Tue Aug 15 13:44:10 2023.
+   You can adapt this file completely to your liking, but it should at least
+   contain the root `toctree` directive.
+
+Megatron Core User Guide
+===================================
+
+**Megatron Core** is a Python library that provides the core components required to build your language models.
+A reference implementation of Megatron Core can be found in `NeMo <https://github.com/NVIDIA/NeMo>`_. It offers a *simple* and
+*intuitive* API.
+
+.. toctree::
+   :maxdepth: 2
+   :caption: User Guide
+
+   user-guide/index
+
+.. toctree::
+   :maxdepth: 3
+   :caption: API Guide
+   
+   api-guide/index
diff --git a/docs/source/user-guide/index.rst b/docs/source/user-guide/index.rst
new file mode 100644
index 0000000000..0fb996a4f0
--- /dev/null
+++ b/docs/source/user-guide/index.rst
@@ -0,0 +1,4 @@
+User Guide 
+============
+
+.. mdinclude:: ../../../megatron/core/QuickStart.md
\ No newline at end of file
diff --git a/examples/detxoify_lm/README.md b/examples/academic_paper_scripts/detxoify_lm/README.md
similarity index 100%
rename from examples/detxoify_lm/README.md
rename to examples/academic_paper_scripts/detxoify_lm/README.md
diff --git a/examples/detxoify_lm/annotations/filter-selfgeneration.py b/examples/academic_paper_scripts/detxoify_lm/annotations/filter-selfgeneration.py
similarity index 100%
rename from examples/detxoify_lm/annotations/filter-selfgeneration.py
rename to examples/academic_paper_scripts/detxoify_lm/annotations/filter-selfgeneration.py
diff --git a/examples/detxoify_lm/annotations/perspective_api_annotate.py b/examples/academic_paper_scripts/detxoify_lm/annotations/perspective_api_annotate.py
similarity index 98%
rename from examples/detxoify_lm/annotations/perspective_api_annotate.py
rename to examples/academic_paper_scripts/detxoify_lm/annotations/perspective_api_annotate.py
index fd82c2a2ae..9736db099a 100644
--- a/examples/detxoify_lm/annotations/perspective_api_annotate.py
+++ b/examples/academic_paper_scripts/detxoify_lm/annotations/perspective_api_annotate.py
@@ -107,7 +107,7 @@ def get_score(line):
             except UnicodeDecodeError:
                 try:
                     decoded_text = encoded_text[:20476].decode('utf8')
-                except:
+                except Exception:
                     print("Error occurred")
                     data['score'] = None
                     return json.dumps(data)
@@ -138,7 +138,7 @@ def get_scores(lines):
                 except UnicodeDecodeError:
                     try:
                         decoded_text = encoded_text[:20476].decode('utf8')
-                    except:
+                    except Exception:
                         print("Error occurred")
                         data['score'] = None
                         all_data.append(json.dumps(data))
diff --git a/examples/detxoify_lm/annotations/preprocess.sh b/examples/academic_paper_scripts/detxoify_lm/annotations/preprocess.sh
similarity index 100%
rename from examples/detxoify_lm/annotations/preprocess.sh
rename to examples/academic_paper_scripts/detxoify_lm/annotations/preprocess.sh
diff --git a/examples/detxoify_lm/finetune_gpt.py b/examples/academic_paper_scripts/detxoify_lm/finetune_gpt.py
similarity index 88%
rename from examples/detxoify_lm/finetune_gpt.py
rename to examples/academic_paper_scripts/detxoify_lm/finetune_gpt.py
index f1bbba5bda..6a3696d388 100644
--- a/examples/detxoify_lm/finetune_gpt.py
+++ b/examples/academic_paper_scripts/detxoify_lm/finetune_gpt.py
@@ -10,19 +10,20 @@
 import sys
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir, os.path.pardir)))
-from megatron import get_args
-from megatron import get_timers
-from megatron import get_tokenizer
-from megatron import print_rank_0
+from megatron.training import get_args
+from megatron.training import get_timers
+from megatron.training import get_tokenizer
+from megatron.training import print_rank_0
 from megatron.core import mpu
 from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
 from megatron.core.datasets.blended_megatron_dataset_config import GPTDatasetConfig
 from megatron.core.datasets.gpt_dataset import GPTDataset
-from megatron.model import GPTModel
+from megatron.core.datasets.utils import get_blend_from_list
+from megatron.legacy.model import GPTModel
 from megatron.core.enums import ModelType
 from megatron.training import pretrain
-from megatron.utils import get_ltor_masks_and_position_ids
-from megatron.utils import average_losses_across_data_parallel_group
+from megatron.training.utils import get_ltor_masks_and_position_ids
+from megatron.training.utils import average_losses_across_data_parallel_group
 
 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
@@ -105,8 +106,9 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
     train_ds, _, test_ds = BlendedMegatronDatasetBuilder(
         GPTDataset,
         train_val_test_num_samples,
+        lambda: True,
         GPTDatasetConfig(
-            blend=args.data_path,
+            blend=get_blend_from_list(args.data_path),
             split=args.split,
             random_seed=args.seed,
             sequence_length=args.seq_length,
@@ -119,8 +121,9 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
     _, valid_ds, _ = BlendedMegatronDatasetBuilder(
         GPTDataset,
         train_val_test_num_samples,
+        lambda: True,
         GPTDatasetConfig(
-            blend=args.data_path2,
+            blend=get_blend_from_list(args.data_path2),
             split="98,2,0",
             random_seed=1234,
             sequence_length=2048,
diff --git a/examples/detxoify_lm/finetune_gpt_distributed-1.3b.sh b/examples/academic_paper_scripts/detxoify_lm/finetune_gpt_distributed-1.3b.sh
similarity index 100%
rename from examples/detxoify_lm/finetune_gpt_distributed-1.3b.sh
rename to examples/academic_paper_scripts/detxoify_lm/finetune_gpt_distributed-1.3b.sh
diff --git a/examples/detxoify_lm/generate-1.3b.sh b/examples/academic_paper_scripts/detxoify_lm/generate-1.3b.sh
similarity index 100%
rename from examples/detxoify_lm/generate-1.3b.sh
rename to examples/academic_paper_scripts/detxoify_lm/generate-1.3b.sh
diff --git a/examples/detxoify_lm/generate_samples_gpt.py b/examples/academic_paper_scripts/detxoify_lm/generate_samples_gpt.py
similarity index 68%
rename from examples/detxoify_lm/generate_samples_gpt.py
rename to examples/academic_paper_scripts/detxoify_lm/generate_samples_gpt.py
index 47e1590ea5..895a45d024 100644
--- a/examples/detxoify_lm/generate_samples_gpt.py
+++ b/examples/academic_paper_scripts/detxoify_lm/generate_samples_gpt.py
@@ -9,23 +9,84 @@
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir, os.path.pardir)))
 import torch
-from megatron import get_args
-from megatron import get_tokenizer
-from megatron import print_rank_0
-from megatron.checkpointing import load_checkpoint
+from megatron.training import get_args
+from megatron.training import get_tokenizer
+from megatron.training import print_rank_0
+from megatron.training.checkpointing import load_checkpoint
 from megatron.core import mpu
-from megatron.initialize import initialize_megatron
-from megatron.model import GPTModel
+from megatron.training.initialize import initialize_megatron
+from megatron.legacy.model import GPTModel
 from megatron.training import get_model
-from megatron.text_generation import generate_and_post_process
+from megatron.inference.text_generation import generate_and_post_process
+from megatron.training.arguments import core_transformer_config_from_args
+from megatron.core.models.gpt import GPTModel
+from typing import Union
+import megatron.legacy.model
+from megatron.core.transformer.spec_utils import import_module
+from megatron.training.arguments import core_transformer_config_from_args
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec, get_gpt_layer_local_spec
 
+def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
+    """Builds the model.
 
-def model_provider(pre_process=True, post_process=True):
-    """Build the model."""
+    If use_legacy_models is set to True, this returns the legacy GPT model; otherwise it returns the core GPT model.
+
+    Args:
+        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
+        post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.
+
+
+    Returns:
+        Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
+    """
+    args = get_args()
 
     print_rank_0('building GPT model ...')
-    model = GPTModel(num_tokentypes=0, parallel_output=False,
-                     pre_process=pre_process, post_process=post_process)
+    config = core_transformer_config_from_args(args)
+
+    if args.use_legacy_models:
+        model = megatron.legacy.model.GPTModel(
+            config,
+            num_tokentypes=0,
+            parallel_output=False,
+            pre_process=pre_process,
+            post_process=post_process
+        )
+    else:
+        if args.spec is None:
+            if args.transformer_impl == 'local':
+                transformer_layer_spec = get_gpt_layer_local_spec(
+                    num_experts=args.num_experts,
+                    moe_grouped_gemm=args.moe_grouped_gemm
+                )
+            elif args.transformer_impl == 'transformer_engine':
+                transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
+                    num_experts=args.num_experts,
+                    moe_grouped_gemm=args.moe_grouped_gemm
+                )
+            else:
+                raise ValueError(f"Invalid transformer_impl {args.transformer_impl}")
+        elif args.spec[0] == 'local':
+            transformer_layer_spec = get_gpt_layer_local_spec(
+                num_experts=args.num_experts,
+                moe_grouped_gemm=args.moe_grouped_gemm
+            )
+        else:
+            transformer_layer_spec = import_module(args.spec)
+
+        model = GPTModel(
+            config=config,
+            transformer_layer_spec=transformer_layer_spec,
+            vocab_size=args.padded_vocab_size,
+            max_sequence_length=args.max_position_embeddings,
+            pre_process=pre_process,
+            post_process=post_process,
+            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
+            parallel_output=False,
+            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
+            position_embedding_type=args.position_embedding_type,
+            rotary_percent=args.rotary_percent
+        )
 
     return model
 
diff --git a/examples/detxoify_lm/perspective_api.py b/examples/academic_paper_scripts/detxoify_lm/perspective_api.py
similarity index 100%
rename from examples/detxoify_lm/perspective_api.py
rename to examples/academic_paper_scripts/detxoify_lm/perspective_api.py
diff --git a/examples/detxoify_lm/self_generation/selfgenerate-1.3b-unconditional.sh b/examples/academic_paper_scripts/detxoify_lm/self_generation/selfgenerate-1.3b-unconditional.sh
similarity index 100%
rename from examples/detxoify_lm/self_generation/selfgenerate-1.3b-unconditional.sh
rename to examples/academic_paper_scripts/detxoify_lm/self_generation/selfgenerate-1.3b-unconditional.sh
diff --git a/examples/msdp/README.md b/examples/academic_paper_scripts/msdp/README.md
similarity index 100%
rename from examples/msdp/README.md
rename to examples/academic_paper_scripts/msdp/README.md
diff --git a/examples/msdp/data_processing.sh b/examples/academic_paper_scripts/msdp/data_processing.sh
similarity index 100%
rename from examples/msdp/data_processing.sh
rename to examples/academic_paper_scripts/msdp/data_processing.sh
diff --git a/examples/msdp/eval_knwl_generation.sh b/examples/academic_paper_scripts/msdp/eval_knwl_generation.sh
similarity index 100%
rename from examples/msdp/eval_knwl_generation.sh
rename to examples/academic_paper_scripts/msdp/eval_knwl_generation.sh
diff --git a/examples/msdp/eval_resp_generation.sh b/examples/academic_paper_scripts/msdp/eval_resp_generation.sh
similarity index 100%
rename from examples/msdp/eval_resp_generation.sh
rename to examples/academic_paper_scripts/msdp/eval_resp_generation.sh
diff --git a/examples/msdp/prep_resp_gen.sh b/examples/academic_paper_scripts/msdp/prep_resp_gen.sh
similarity index 100%
rename from examples/msdp/prep_resp_gen.sh
rename to examples/academic_paper_scripts/msdp/prep_resp_gen.sh
diff --git a/examples/msdp/prompt_knwl_gen.sh b/examples/academic_paper_scripts/msdp/prompt_knwl_gen.sh
similarity index 100%
rename from examples/msdp/prompt_knwl_gen.sh
rename to examples/academic_paper_scripts/msdp/prompt_knwl_gen.sh
diff --git a/examples/msdp/prompt_resp_gen.sh b/examples/academic_paper_scripts/msdp/prompt_resp_gen.sh
similarity index 100%
rename from examples/msdp/prompt_resp_gen.sh
rename to examples/academic_paper_scripts/msdp/prompt_resp_gen.sh
diff --git a/examples/sc21/CONFIG.sh b/examples/academic_paper_scripts/sc21/CONFIG.sh
similarity index 100%
rename from examples/sc21/CONFIG.sh
rename to examples/academic_paper_scripts/sc21/CONFIG.sh
diff --git a/examples/sc21/README.md b/examples/academic_paper_scripts/sc21/README.md
similarity index 100%
rename from examples/sc21/README.md
rename to examples/academic_paper_scripts/sc21/README.md
diff --git a/examples/sc21/SBATCH.sh b/examples/academic_paper_scripts/sc21/SBATCH.sh
similarity index 100%
rename from examples/sc21/SBATCH.sh
rename to examples/academic_paper_scripts/sc21/SBATCH.sh
diff --git a/examples/sc21/SRUN.sh b/examples/academic_paper_scripts/sc21/SRUN.sh
similarity index 100%
rename from examples/sc21/SRUN.sh
rename to examples/academic_paper_scripts/sc21/SRUN.sh
diff --git a/examples/sc21/run_figure_11.sh b/examples/academic_paper_scripts/sc21/run_figure_11.sh
similarity index 100%
rename from examples/sc21/run_figure_11.sh
rename to examples/academic_paper_scripts/sc21/run_figure_11.sh
diff --git a/examples/sc21/run_figure_12.sh b/examples/academic_paper_scripts/sc21/run_figure_12.sh
similarity index 100%
rename from examples/sc21/run_figure_12.sh
rename to examples/academic_paper_scripts/sc21/run_figure_12.sh
diff --git a/examples/sc21/run_figure_13.sh b/examples/academic_paper_scripts/sc21/run_figure_13.sh
similarity index 100%
rename from examples/sc21/run_figure_13.sh
rename to examples/academic_paper_scripts/sc21/run_figure_13.sh
diff --git a/examples/sc21/run_figure_14.sh b/examples/academic_paper_scripts/sc21/run_figure_14.sh
similarity index 100%
rename from examples/sc21/run_figure_14.sh
rename to examples/academic_paper_scripts/sc21/run_figure_14.sh
diff --git a/examples/sc21/run_figure_15.sh b/examples/academic_paper_scripts/sc21/run_figure_15.sh
similarity index 100%
rename from examples/sc21/run_figure_15.sh
rename to examples/academic_paper_scripts/sc21/run_figure_15.sh
diff --git a/examples/sc21/run_figure_16.sh b/examples/academic_paper_scripts/sc21/run_figure_16.sh
similarity index 100%
rename from examples/sc21/run_figure_16.sh
rename to examples/academic_paper_scripts/sc21/run_figure_16.sh
diff --git a/examples/sc21/run_figure_17.sh b/examples/academic_paper_scripts/sc21/run_figure_17.sh
similarity index 100%
rename from examples/sc21/run_figure_17.sh
rename to examples/academic_paper_scripts/sc21/run_figure_17.sh
diff --git a/examples/sc21/run_figure_18.sh b/examples/academic_paper_scripts/sc21/run_figure_18.sh
similarity index 100%
rename from examples/sc21/run_figure_18.sh
rename to examples/academic_paper_scripts/sc21/run_figure_18.sh
diff --git a/examples/sc21/run_table_1.sh b/examples/academic_paper_scripts/sc21/run_table_1.sh
similarity index 100%
rename from examples/sc21/run_table_1.sh
rename to examples/academic_paper_scripts/sc21/run_table_1.sh
diff --git a/examples/bert/README.md b/examples/bert/README.md
new file mode 100644
index 0000000000..6c1fe95bf0
--- /dev/null
+++ b/examples/bert/README.md
@@ -0,0 +1,53 @@
+# BERT MODEL
+
+## Table of contents
+- [1. Training Setup](#1-training-setup)
+- [2. Configurations](#2-configurations)
+
+## 1. Training setup
+
+
+To run the model using a Docker container, run it as follows:
+```
+PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.01-py3
+CHECKPOINT_PATH="" #<Specify path>
+TENSORBOARD_LOGS_PATH="" #<Specify path>
+VOCAB_FILE="" #<Specify path to file>/bert-vocab.txt
+DATA_PATH="" #<Specify path and file prefix>_text_document
+
+docker run \
+  --gpus=all \
+  --ipc=host \
+  --workdir /workspace/megatron-lm \
+  -v /path/to/data:/path/to/data \
+  -v /path/to/megatron-lm:/workspace/megatron-lm \
+  --name megatron-lm $PYTORCH_IMAGE \
+  bash examples/bert/train_bert_340m_distributed.sh $CHECKPOINT_PATH $TENSORBOARD_LOGS_PATH $VOCAB_FILE $DATA_PATH
+
+```
+NOTE: Depending on the environment you are running it in, the above command may look slightly different.
+
+
+## 2. Configurations
+
+The example in this folder shows you how to run the 340M parameter model. There are other configurations you could run as well:
+
+### 4B
+```
+       --num-layers 48 \
+       --hidden-size 2560 \
+       --num-attention-heads 32 \
+       --tensor-model-parallel-size 1 \
+       --pipeline-model-parallel-size 1 \
+
+```
+
+### 20B
+```
+       --num-layers 48 \
+       --hidden-size 6144 \
+       --num-attention-heads 96 \
+       --tensor-model-parallel-size 4 \
+       --pipeline-model-parallel-size 4 \
+
+```
\ No newline at end of file
diff --git a/examples/bert/train_bert_340m_distributed.sh b/examples/bert/train_bert_340m_distributed.sh
new file mode 100644
index 0000000000..f0d9c87c8b
--- /dev/null
+++ b/examples/bert/train_bert_340m_distributed.sh
@@ -0,0 +1,79 @@
+#!/bin/bash
+
+# Runs the "340M" parameter model (Bert - Large)
+
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+GPUS_PER_NODE=8
+# Change for multinode config
+MASTER_ADDR=localhost
+MASTER_PORT=6000
+NUM_NODES=1
+NODE_RANK=0
+WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES))
+
+CHECKPOINT_PATH=$1 #<Specify path>
+TENSORBOARD_LOGS_PATH=$2 #<Specify path>
+VOCAB_FILE=$3 #<Specify path to file>/bert-vocab.json
+DATA_PATH=$4 #<Specify path and file prefix>_text_document
+
+DISTRIBUTED_ARGS=(
+    --nproc_per_node $GPUS_PER_NODE 
+    --nnodes $NUM_NODES 
+    --master_addr $MASTER_ADDR 
+    --master_port $MASTER_PORT
+)
+
+BERT_MODEL_ARGS=(
+    --num-layers 24 
+    --hidden-size 1024 
+    --num-attention-heads 16 
+    --seq-length 512 
+    --max-position-embeddings 512 
+    --attention-backend auto # Can use (flash/fused/unfused/local)
+)
+
+TRAINING_ARGS=(
+    --micro-batch-size 4
+    --global-batch-size 32
+    --train-iters 1000000
+    --weight-decay 1e-2
+    --clip-grad 1.0
+    --fp16
+    --lr 0.0001
+    --lr-decay-iters 990000
+    --lr-decay-style linear
+    --min-lr 1.0e-5
+    --lr-warmup-fraction .01
+)
+
+MODEL_PARALLEL_ARGS=(
+	--tensor-model-parallel-size 8 
+	--pipeline-model-parallel-size 16 
+)
+
+DATA_ARGS=(
+    --data-path $DATA_PATH 
+    --vocab-file $VOCAB_FILE 
+    --split 949,50,1
+)
+
+EVAL_AND_LOGGING_ARGS=(
+    --log-interval 100
+    --save-interval 10000 
+    --eval-interval 1000 
+    --save $CHECKPOINT_PATH 
+    --load $CHECKPOINT_PATH 
+    --eval-iters 10
+    --tensorboard-dir $TENSORBOARD_LOGS_PATH 
+)
+
+torchrun ${DISTRIBUTED_ARGS[@]} pretrain_bert.py \
+    ${BERT_MODEL_ARGS[@]} \
+    ${TRAINING_ARGS[@]} \
+    ${MODEL_PARALLEL_ARGS[@]} \
+    ${DATA_ARGS[@]} \
+    ${EVAL_AND_LOGGING_ARGS[@]}
+    
\ No newline at end of file
diff --git a/examples/evaluate_retriever_nq.sh b/examples/evaluate_retriever_nq.sh
deleted file mode 100644
index a579b5fd94..0000000000
--- a/examples/evaluate_retriever_nq.sh
+++ /dev/null
@@ -1,37 +0,0 @@
-#!/bin/bash
-
-# Evaluate natural question test data given Wikipedia embeddings and pretrained
-# ICT model or a finetuned model for Natural Question task
-
-# Datasets can be downloaded from the following link:
-# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
-
-EVIDENCE_DATA_DIR=
-EMBEDDING_PATH=
-CHECKPOINT_PATH=
-
-QA_FILE=
-
-python tasks/main.py \
-    --task RETRIEVER-EVAL \
-    --tokenizer-type BertWordPieceLowerCase \
-    --num-layers 12 \
-    --hidden-size 768 \
-    --num-attention-heads 12 \
-    --tensor-model-parallel-size 1 \
-    --micro-batch-size 128 \
-    --seq-length 512 \
-    --max-position-embeddings 512 \
-    --load ${CHECKPOINT_PATH} \
-    --evidence-data-path ${EVIDENCE_DATA_DIR} \
-    --embedding-path ${EMBEDDING_PATH} \
-    --retriever-seq-length 256 \
-    --vocab-file  bert-vocab.txt\
-    --qa-data-test ${QA_FILE} \
-    --faiss-use-gpu \
-    --retriever-report-topk-accuracies 1 5 20 100 \
-    --fp16 \
-    --indexer-log-interval 1000 \
-    --indexer-batch-size 128
-
-
diff --git a/examples/evaluate_zeroshot_gpt.sh b/examples/evaluate_zeroshot_gpt.sh
deleted file mode 100755
index 2cc1c5a760..0000000000
--- a/examples/evaluate_zeroshot_gpt.sh
+++ /dev/null
@@ -1,37 +0,0 @@
-#!/bin/bash
-
-WORLD_SIZE=8
-
-DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
-                  --nnodes 1 \
-                  --node_rank 0 \
-                  --master_addr localhost \
-                  --master_port 6000"
-
-TASK="LAMBADA"
-
-VALID_DATA=
-VOCAB_FILE=gpt2-vocab.json
-MERGE_FILE=gpt2-merges.txt
-CHECKPOINT=checkpoints/gpt2_345m
-
-
-python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
-               --task $TASK \
-               --valid-data $VALID_DATA \
-               --tokenizer-type GPT2BPETokenizer \
-               --strict-lambada \
-               --vocab-file $VOCAB_FILE \
-               --merge-file $MERGE_FILE \
-               --load $CHECKPOINT \
-               --tensor-model-parallel-size 1 \
-               --num-layers 24 \
-               --hidden-size 1024 \
-               --num-attention-heads 16 \
-               --batch-size 8 \
-               --seq-length 1024 \
-               --max-position-embeddings 1024 \
-               --log-interval 10 \
-               --fp16 \
-               --no-load-optim \
-               --no-load-rng
diff --git a/examples/export/README.md b/examples/export/README.md
new file mode 100644
index 0000000000..ddb8216f94
--- /dev/null
+++ b/examples/export/README.md
@@ -0,0 +1,10 @@
+# Megatron Core Export
+
+This module is used to export Megatron Core models to different inference frameworks.
+Currently we support TRTLLM export. In the future we will be adding support for vLLM, etc.
+
+## PTQ AND EXPORT
+Follow the instructions in [ptq_and_trtllm_export](./ptq_and_trtllm_export) to do post-training quantization, followed by an export to the TRTLLM format.
+
+## TRTLLM EXPORT
+Follow the instructions in [trtllm_export](./trtllm_export/) to export to the TRTLLM checkpoint format alone.
\ No newline at end of file
diff --git a/examples/export/knowledge_distillation/pretrain_gpt_modelopt.py b/examples/export/knowledge_distillation/pretrain_gpt_modelopt.py
new file mode 100644
index 0000000000..65d0727d8c
--- /dev/null
+++ b/examples/export/knowledge_distillation/pretrain_gpt_modelopt.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
+
+"""Pretrain GPT."""
+import os
+import sys
+from functools import partial
+
+# This file isn't located in the project root, so add the root to sys.path so that megatron imports resolve.
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")))
+
+from megatron.core import mpu
+from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
+from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset
+from megatron.core.datasets.utils import get_blend_from_list
+from megatron.core.enums import ModelType
+from megatron.core.models.gpt import GPTModel
+from megatron.core.utils import StragglerDetector
+from megatron.inference.arguments import add_modelopt_args
+from megatron.inference.gpt import loss_func, model_provider
+from megatron.training import get_args, get_timers, get_tokenizer, pretrain
+from megatron.training.utils import (
+    get_batch_on_this_cp_rank,
+    get_batch_on_this_tp_rank,
+    print_rank_0,
+)
+
+stimer = StragglerDetector()
+
+
+def get_batch(data_iterator):
+    """Generate a batch."""
+
+    # TODO: this is pretty hacky, find a better way
+    if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
+        return None, None, None, None, None
+
+    # get batches based on the TP rank you are on
+    batch = get_batch_on_this_tp_rank(data_iterator)
+
+    # slice batch along sequence dimension for context parallelism
+    batch = get_batch_on_this_cp_rank(batch)
+
+    return batch.values()
+
+
+def forward_step(data_iterator, model: GPTModel):
+    """Forward training step.
+
+    Args:
+        data_iterator : Input data iterator
+        model (GPTModel): The GPT Model
+    """
+    timers = get_timers()
+
+    # Get the batch.
+    timers('batch-generator', log_level=2).start()
+    global stimer
+    with stimer(bdata=True):
+        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
+    timers('batch-generator').stop()
+
+    with stimer:
+        output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
+
+    # [ModelOpt]: model is needed to access ModelOpt distillation losses
+    return output_tensor, partial(loss_func, loss_mask, model)
+
+
+def is_dataset_built_on_rank():
+    return (
+        mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
+    ) and mpu.get_tensor_model_parallel_rank() == 0
+
+
+def core_gpt_dataset_config_from_args(args):
+    tokenizer = get_tokenizer()
+
+    return GPTDatasetConfig(
+        random_seed=args.seed,
+        sequence_length=args.seq_length,
+        blend=get_blend_from_list(args.data_path),
+        blend_per_split=[
+            get_blend_from_list(args.train_data_path),
+            get_blend_from_list(args.valid_data_path),
+            get_blend_from_list(args.test_data_path),
+        ],
+        split=args.split,
+        num_dataset_builder_threads=args.num_dataset_builder_threads,
+        path_to_cache=args.data_cache_path,
+        mmap_bin_files=args.mmap_bin_files,
+        tokenizer=tokenizer,
+        reset_position_ids=args.reset_position_ids,
+        reset_attention_mask=args.reset_attention_mask,
+        eod_mask_loss=args.eod_mask_loss,
+        create_attention_mask=args.create_attention_mask_in_dataloader,
+    )
+
+
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+    """Build the train test and validation datasets.
+
+    Args:
+        train_val_test_num_samples : A list containing the number of samples in train test and validation.
+    """
+    args = get_args()
+
+    config = core_gpt_dataset_config_from_args(args)
+
+    if args.mock_data:
+        dataset_type = MockGPTDataset
+    else:
+        dataset_type = GPTDataset
+
+    print_rank_0("> building train, validation, and test datasets for GPT ...")
+
+    train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
+        dataset_type, train_val_test_num_samples, is_dataset_built_on_rank, config
+    ).build()
+
+    print_rank_0("> finished creating GPT datasets ...")
+
+    return train_ds, valid_ds, test_ds
+
+
+if __name__ == "__main__":
+    # Temporary for transition to core datasets
+    train_valid_test_datasets_provider.is_distributed = True
+
+    pretrain(
+        train_valid_test_datasets_provider,
+        model_provider,
+        ModelType.encoder_or_decoder,
+        forward_step,
+        args_defaults={"tokenizer_type": "GPT2BPETokenizer"},
+        extra_args_provider=add_modelopt_args,
+    )
diff --git a/examples/export/ptq_and_trtllm_export/README.md b/examples/export/ptq_and_trtllm_export/README.md
new file mode 100644
index 0000000000..2605910869
--- /dev/null
+++ b/examples/export/ptq_and_trtllm_export/README.md
@@ -0,0 +1,295 @@
+# Megatron Model Optimization and Deployment
+
+## Installation
+We recommend that users follow TensorRT-LLM's official installation guide to build it from source
+and proceed with a containerized environment (`docker.io/tensorrt_llm/release:latest`):
+
+```sh
+git clone https://github.com/NVIDIA/TensorRT-LLM.git
+cd TensorRT-LLM
+git checkout v0.10.0
+make -C docker release_build
+```
+
+> **TROUBLESHOOTING:** rather than copying each folder separately in `docker/Dockerfile.multi`,
+> you may need to copy the entire directory with `COPY ./ /src/tensorrt_llm`, since a `git submodule` command is
+> called later which requires `.git` to be present.
+
+Once the container is built, install `nvidia-modelopt` and additional dependencies for sharded checkpoint support:
+```sh
+pip install "nvidia-modelopt[all]~=0.13.0" --extra-index-url https://pypi.nvidia.com
+pip install zarr tensorstore!=0.1.46
+```
+TensorRT-LLM quantization functionalities are currently packaged in `nvidia-modelopt`.
+You can find more documentation about `nvidia-modelopt` [here](https://nvidia.github.io/TensorRT-Model-Optimizer/).
+
+## Support Matrix
+
+The following matrix shows the current support for the PTQ + TensorRT-LLM export flow.
+
+| model                       | fp16 | int8_sq | fp8 | int4_awq |
+|-----------------------------|------|---------| ----| -------- |
+| nextllm-2b                  | x    | x       |   x |          |
+| nemotron3-8b                | x    |         |   x |          |
+| nemotron3-15b               | x    |         |   x |          |
+| llama2-text-7b              | x    | x       |   x |      TP2 |
+| llama2-chat-70b             | x    | x       |   x |      TP4 |
+
+Our PTQ + TensorRT-LLM flow has native support for MCore `GPTModel` with a mixed layer spec (native `ParallelLinear`
+and Transformer-Engine Norm, `TENorm`). Note that this is not the default MCore GPT spec. You can still load the
+following checkpoint formats with some remedy arguments:
+
+| GPTModel                          | sharded |                        remedy arguments     |
+|-----------------------------------|---------|---------------------------------------------|
+| megatron.legacy.model             |         | `--export-legacy-megatron` |
+| TE-Fused (default mcore gpt spec) |         | `--export-te-mcore-model`       |
+| TE-Fused (default mcore gpt spec) |       x |                                             |
+
+> **TROUBLESHOOTING:** If you are trying to load an unpacked `.nemo` sharded checkpoint, then you will typically
+> need to add `additional_sharded_prefix="model."` to `modelopt_load_checkpoint()`, since NeMo wraps the `GPTModel`
+> with an additional `model.` prefix.
+
+> **NOTE:** The `--export-legacy-megatron` flag may not work on all legacy checkpoint versions.
+
+## Examples
+
+> **NOTE:** we only provide a simple text generation script to test the generated TensorRT-LLM engines. For
+> a production-level API server or enterprise support, see [NeMo](https://github.com/NVIDIA/NeMo) and TensorRT-LLM's
+> backend for [NVIDIA Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server).
+
+### Minitron-8B FP8 Quantization and TensorRT-LLM Deployment
+First download the Minitron checkpoint from https://huggingface.co/nvidia/Minitron-8B-Base, extract the
+sharded checkpoint from the `.nemo` tarball, and fix the tokenizer file name.
+
+> **NOTE:** The following cloning method uses `ssh` and assumes you have registered your `ssh-key` with Hugging Face.
+> If you want to clone with `https`, run `git clone https://huggingface.co/nvidia/Minitron-8B-Base` with an access token instead.
+
+```sh
+git lfs install
+git clone git@hf.co:nvidia/Minitron-8B-Base
+cd Minitron-8B-Base/nemo
+tar -xvf minitron-8b-base.nemo
+cd ../..
+```
+
+Now launch the PTQ + TensorRT-LLM export script,
+```sh
+bash examples/export/ptq_and_trtllm_export/ptq_trtllm_minitron_8b.sh ./Minitron-8B-Base None
+```
+By default, `cnn_dailymail` is used for calibration. The `GPTModel` will contain quantizers that simulate the
+quantization effect. The checkpoint can optionally be saved (with the quantizers as additional states) and
+restored later for further evaluation or quantization-aware training. The TensorRT-LLM checkpoint is exported to
+`/tmp/trtllm_ckpt` and the engine is built in `/tmp/trtllm_engine` by default.
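+
+The calibration that the script drives internally boils down to ModelOpt's `mtq.quantize(model, config, forward_loop)`
+call (see `text_generation_ptq.py`). The snippet below is a minimal, self-contained sketch of that call shape only:
+the toy module and random batches are placeholders, whereas the real script runs `cnn_dailymail` prompts through the
+Megatron `GPTModel`.
+
+```python
+# Minimal sketch of the ModelOpt PTQ calibration flow (toy model, not the Megatron GPTModel).
+import modelopt.torch.quantization as mtq
+import torch
+
+model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64))
+
+def forward_loop(m):
+    # The real script feeds calibration prompts via generate_and_post_process; random data stands in here.
+    for _ in range(8):
+        m(torch.randn(4, 64))
+
+# INT8_DEFAULT_CFG is one of the recipes exposed by the script (alongside int8_sq, fp8, int4_awq, ...).
+mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)
+print(model)
+```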
+
+The script expects `${CHECKPOINT_DIR}` (`./Minitron-8B-Base/nemo`) to have the following structure:
+
+> **NOTE:** Every extracted `.nemo` checkpoint (including those in the examples below) should have the following structure.
+
+```
+├── model_weights
+│   ├── common.pt
+│   ...
+│
+├── model_config.yaml
+│...
+```
+
+> **NOTE:** The script uses `TP=8`. Change `$TP` in the script if your checkpoint uses a different tensor
+> model parallel size.
+
+Then build the TensorRT engine and run the text generation example using the newly built engine:
+
+```sh
+export trtllm_options=" \
+    --checkpoint_dir /tmp/trtllm_ckpt \
+    --output_dir /tmp/trtllm_engine \
+    --max_input_len 2048 \
+    --max_seq_len 512 \
+    --max_batch_size 8 "
+
+trtllm-build ${trtllm_options}
+
+python examples/export/ptq_and_trtllm_export/trtllm_text_generation.py --tokenizer nvidia/Minitron-8B-Base
+```
+
+### Mistral NeMo 12B FP8 Quantization and TensorRT-LLM Deployment
+First download the Mistral NeMo checkpoint from https://huggingface.co/nvidia/Mistral-NeMo-12B-Base and extract the
+sharded checkpoint from the `.nemo` tarball.
+
+> **NOTE:** The following cloning method uses `ssh` and assumes you have registered your `ssh-key` with Hugging Face.
+> If you want to clone with `https`, run `git clone https://huggingface.co/nvidia/Mistral-NeMo-12B-Base` with an access token instead.
+
+```sh
+git lfs install
+git clone git@hf.co:nvidia/Mistral-NeMo-12B-Base
+cd Mistral-NeMo-12B-Base
+tar -xvf Mistral-NeMo-12B-Base.nemo
+cd ..
+```
+
+Then log in to Hugging Face so that you can access the model.
+
+> **NOTE:** You need a token generated at huggingface.co/settings/tokens and access to mistralai/Mistral-Nemo-Base-2407 on Hugging Face.
+
+```sh
+pip install -U "huggingface_hub[cli]"
+huggingface-cli login
+```
+
+Now launch the PTQ + TensorRT-LLM checkpoint export script,
+
+```sh
+bash examples/export/ptq_and_trtllm_export/ptq_trtllm_mistral_12b.sh ./Mistral-NeMo-12B-Base None
+```
+
+Then build the TensorRT engine and run the text generation example using the newly built engine:
+
+```sh
+export trtllm_options=" \
+    --checkpoint_dir /tmp/trtllm_ckpt \
+    --output_dir /tmp/trtllm_engine \
+    --max_input_len 2048 \
+    --max_seq_len 512 \
+    --max_batch_size 8 "
+
+trtllm-build ${trtllm_options}
+
+python examples/export/ptq_and_trtllm_export/trtllm_text_generation.py --tokenizer mistralai/Mistral-Nemo-Base-2407
+```
+
+
+### llama2-text-7b INT8 SmoothQuant and TensorRT-LLM Deployment
+> **NOTE:** Due to licensing restrictions, we do not provide an MCore checkpoint for download. Users can follow
+> the instructions in `docs/llama2.md` to convert the checkpoint to the Megatron legacy `GPTModel` format and
+> use the `--export-legacy-megatron` flag, which remaps the checkpoint to the MCore `GPTModel` spec
+> that we support.
+
+```sh
+bash examples/export/ptq_and_trtllm_export/ptq_trtllm_llama2_7b.sh ${CHECKPOINT_DIR}
+```
+
+The script expects `${CHECKPOINT_DIR}` to have the following structure:
+```
+├── hf
+│   ├── tokenizer.config
+│   ├── tokenizer.model
+│   ...
+│
+├── iter_0000001
+│   ├── mp_rank_00
+│   ...
+│
+├── latest_checkpointed_iteration.txt
+```
+In short, alongside the converted LLaMA Megatron checkpoint, also place the Hugging Face checkpoint inside
+`${CHECKPOINT_DIR}` as the source of the tokenizer.
+
+Then build the TensorRT engine and run the text generation example using the newly built engine:
+
+```sh
+export trtllm_options=" \
+    --checkpoint_dir /tmp/trtllm_ckpt \
+    --output_dir /tmp/trtllm_engine \
+    --max_input_len 2048 \
+    --max_seq_len 512 \
+    --max_batch_size 8 "
+
+trtllm-build ${trtllm_options}
+
+python examples/export/ptq_and_trtllm_export/trtllm_text_generation.py --tokenizer meta-llama/Llama-2-7b
+```
+
+### llama3-8b / llama3.1-8b INT8 SmoothQuant and TensorRT-LLM Deployment
+> **NOTE:** For llama3.1, the missing rope_scaling parameter will be fixed in modelopt-0.19 and trtllm-0.13.
+
+> **NOTE:** There are two ways to acquire the checkpoint. Users can follow
+> the instructions in `docs/llama2.md` to convert the checkpoint to the Megatron legacy `GPTModel` format and
+> use the `--export-legacy-megatron` flag, which remaps the checkpoint to the MCore `GPTModel` spec
+> that we support.
+> Alternatively, users can download the [nemo model](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/llama38bnemo) from NGC and extract the sharded checkpoint from the `.nemo` tarball.
+
+If users choose to download the model from NGC, first extract the sharded checkpoint from the `.nemo` tarball.
+
+```sh
+tar -xvf 8b_pre_trained_bf16.nemo
+```
+
+> **NOTE:** You need a token generated at huggingface.co/settings/tokens and access to meta-llama/Llama-3.1-8B or meta-llama/Llama-3-8B on Hugging Face.
+
+```sh
+pip install -U "huggingface_hub[cli]"
+huggingface-cli login
+```
+
+Now launch the PTQ + TensorRT-LLM checkpoint export script for llama-3,
+
+```sh
+bash examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_8b.sh ./llama-3-8b-nemo_v1.0 None
+```
+
+or llama-3.1
+
+```sh
+bash examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_1_8b.sh ./llama-3_1-8b-nemo_v1.0 None
+```
+
+Then build the TensorRT engine and run the text generation example using the newly built engine:
+
+```sh
+export trtllm_options=" \
+    --checkpoint_dir /tmp/trtllm_ckpt \
+    --output_dir /tmp/trtllm_engine \
+    --max_input_len 2048 \
+    --max_seq_len 512 \
+    --max_batch_size 8 "
+
+trtllm-build ${trtllm_options}
+
+# For llama-3
+python examples/export/ptq_and_trtllm_export/trtllm_text_generation.py --tokenizer meta-llama/Meta-Llama-3-8B
+
+# For llama-3.1
+python examples/export/ptq_and_trtllm_export/trtllm_text_generation.py --tokenizer meta-llama/Meta-Llama-3.1-8B
+```
+
+
+### Mixtral-8x7B FP8 Quantization and TensorRT-LLM Deployment
+First download the Mixtral checkpoint from https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mixtral-8x7b-v01 and extract the
+sharded checkpoint from the `.nemo` tarball.
+
+```sh
+ngc registry model download-version "nvidia/nemo/mixtral-8x7b-v01:1.0"
+cd mixtral-8x7b-v01_v1.0
+tar -xvf mixtral.nemo
+cd ..
+```
+
+Then log in to Hugging Face so that you can access the model.
+
+> **NOTE:** You need a token generated at huggingface.co/settings/tokens and access to mistralai/Mixtral-8x7B-v0.1 on Hugging Face.
+
+```sh
+pip install -U "huggingface_hub[cli]"
+huggingface-cli login
+```
+
+Now launch the PTQ + TensorRT-LLM checkpoint export script,
+
+```sh
+bash examples/export/ptq_and_trtllm_export/ptq_trtllm_mixtral_8x7b.sh ./mixtral-8x7b-v01_v1.0/
+```
+
+Then build the TensorRT engine and run the text generation example using the newly built engine:
+
+```sh
+export trtllm_options=" \
+    --checkpoint_dir /tmp/trtllm_ckpt \
+    --output_dir /tmp/trtllm_engine \
+    --max_input_len 2048 \
+    --max_seq_len 512 \
+    --max_batch_size 8 "
+
+trtllm-build ${trtllm_options}
+
+python examples/export/ptq_and_trtllm_export/trtllm_text_generation.py --tokenizer mistralai/Mixtral-8x7B-v0.1
+```
diff --git a/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama2_7b.sh b/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama2_7b.sh
new file mode 100644
index 0000000000..ebcc448955
--- /dev/null
+++ b/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama2_7b.sh
@@ -0,0 +1,80 @@
+#!/bin/bash
+set -e
+
+DEFAULT_NAME="/checkpoints/llama2-text-7b_v0.2.0"
+NAME="${1:-$DEFAULT_NAME}"
+
+DEFAULT_QUANT_CFG="int8_sq"
+QUANT_CFG="${2:-$DEFAULT_QUANT_CFG}"
+
+# CHANGE THE FOLLOWING IF YOU MOUNT YOUR DATA AND CHECKPOINTS DIFFERENTLY IN THE CONTAINER.
+TP="8"
+INFERENCE_TP=${TP}
+DECODER_TYPE="llama"
+CHECKPOINT_LOAD_DIR="${NAME}"
+TOKENIZER_MODEL="${CHECKPOINT_LOAD_DIR}/hf/tokenizer.model"
+
+# LLaMA2 text 7B has ffn_hidden_size 11008. int4_awq requires a block_size of 128, so TP can be at most 2.
+if [ "$QUANT_CFG" = "int4_awq" ]; then
+    INFERENCE_TP="2"
+fi
+
+additional_options=" \
+    --export-quant-cfg ${QUANT_CFG} \
+    --export-legacy-megatron \
+    --export-te-mcore-model \
+    --calib-batch-size 8 \
+    --decoder ${DECODER_TYPE} \
+    --export-dir /tmp/trtllm_ckpt \
+    --inference-tensor-parallel ${INFERENCE_TP} "
+
+trtllm_options=" \
+    --tensorrt-llm-checkpoint-dir /tmp/trtllm_ckpt \
+    --engine-dir /tmp/trtllm_engine \
+    --tokenizer ${CHECKPOINT_LOAD_DIR}/hf \
+    --max-input-len 2048 \
+    --max-output-len 512 \
+    --max-batch-size 8 "
+
+# DO NOT CHANGE THE SETTING BELOW UNLESS YOU KNOW WHAT YOU ARE DOING!!!
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+options=" \
+    --disable-bias-linear \
+    --swiglu \
+    --no-rope-fusion \
+    --untie-embeddings-and-output-weights \
+    --use-rotary-position-embeddings \
+    --normalization RMSNorm \
+    --rotary-percent 1.0 \
+    --no-position-embedding \
+    --no-masked-softmax-fusion \
+    --no-bias-gelu-fusion \
+    --no-bias-dropout-fusion \
+    --no-async-tensor-model-parallel-allreduce \
+    --tensor-model-parallel-size ${TP} \
+    --pipeline-model-parallel-size 1 \
+    --num-layers 32 \
+    --hidden-size 4096 \
+    --ffn-hidden-size 11008 \
+    --num-attention-heads 32 \
+    --seq-length 4096 \
+    --max-position-embeddings 4096 \
+    --micro-batch-size 1 \
+    --make-vocab-size-divisible-by 1 \
+    --tokenizer-type Llama2Tokenizer \
+    --tokenizer-model ${TOKENIZER_MODEL} \
+    --save-interval 1000000 \
+    --use-dist-ckpt \
+    --load ${CHECKPOINT_LOAD_DIR} \
+    --fp16"
+
+# Precompile CUDA extensions
+python -c "import modelopt.torch.quantization.extensions as ext; print(ext.cuda_ext); print(ext.cuda_ext_fp8)"
+
+# Set the torchrun launch configuration: one process per tensor-parallel rank
+launch_config="--nproc_per_node=${TP}"
+
+# Launch multi-process with torchrun
+torchrun ${launch_config} examples/export/ptq_and_trtllm_export/text_generation_ptq.py ${options} ${additional_options}
+
diff --git a/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_1_8b.sh b/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_1_8b.sh
new file mode 100644
index 0000000000..94ee12db41
--- /dev/null
+++ b/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_1_8b.sh
@@ -0,0 +1,75 @@
+#!/bin/bash
+set -e
+
+DEFAULT_NAME="/checkpoints/llama-3_1-8b-nemo_v1.0"
+NAME="${1:-$DEFAULT_NAME}"
+
+DEFAULT_QUANT_CFG="int8_sq"
+QUANT_CFG="${2:-$DEFAULT_QUANT_CFG}"
+
+# CHANGE THE FOLLOWING IF YOU MOUNT YOUR DATA AND CHECKPOINTS DIFFERENTLY IN THE CONTAINER.
+TP="1"
+INFERENCE_TP=${TP}
+DECODER_TYPE="llama"
+CHECKPOINT_LOAD_DIR="${NAME}"
+
+# int4_awq requires a weight block_size of 128, which limits the maximum inference TP.
+if [ "$QUANT_CFG" = "int4_awq" ]; then
+    INFERENCE_TP="2"
+fi
+
+additional_options=" \
+    --export-quant-cfg ${QUANT_CFG} \
+    --export-legacy-megatron \
+    --export-te-mcore-model \
+    --calib-batch-size 8 \
+    --decoder ${DECODER_TYPE} \
+    --export-dir /tmp/trtllm_ckpt \
+    --inference-tensor-parallel ${INFERENCE_TP} "
+
+# DO NOT CHANGE THE SETTING BELOW UNLESS YOU KNOW WHAT YOU ARE DOING!!!
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+options=" \
+    --disable-bias-linear \
+    --attention-backend unfused \
+    --swiglu \
+    --no-rope-fusion \
+    --untie-embeddings-and-output-weights \
+    --use-rotary-position-embeddings \
+    --normalization RMSNorm \
+    --rotary-percent 1.0 \
+    --hidden-dropout 0.0 \
+    --attention-dropout 0.0 \
+    --no-bias-gelu-fusion \
+    --no-bias-dropout-fusion \
+    --no-async-tensor-model-parallel-allreduce \
+    --tensor-model-parallel-size ${TP} \
+    --pipeline-model-parallel-size 1 \
+    --num-layers 32 \
+    --hidden-size 4096 \
+    --group-query-attention \
+    --num-query-groups 8 \
+    --ffn-hidden-size 14336 \
+    --num-attention-heads 32 \
+    --seq-length 131072 \
+    --max-position-embeddings 131072 \
+    --micro-batch-size 4 \
+    --make-vocab-size-divisible-by 128 \
+    --tokenizer-type HuggingFaceTokenizer \
+    --tokenizer-model meta-llama/Meta-Llama-3.1-8B \
+    --save-interval 1000000 \
+    --use-rope-scaling \
+    --use-dist-ckpt \
+    --load ${CHECKPOINT_LOAD_DIR} \
+    --rotary-base 500000 \
+    --fp16"
+
+# Precompile CUDA extensions
+python -c "import modelopt.torch.quantization.extensions as ext; print(ext.cuda_ext); print(ext.cuda_ext_fp8)"
+
+# Set the torchrun launch configuration: one process per tensor-parallel rank
+launch_config="--nproc_per_node=${TP}"
+
+# Launch multi-process with torchrun
+torchrun ${launch_config} examples/export/ptq_and_trtllm_export/text_generation_ptq.py ${options} ${additional_options}
diff --git a/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_8b.sh b/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_8b.sh
new file mode 100644
index 0000000000..dfa5a80c26
--- /dev/null
+++ b/examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_8b.sh
@@ -0,0 +1,75 @@
+#!/bin/bash
+set -e
+
+DEFAULT_NAME="/checkpoints/llama-3_1-8b-nemo_v1.0"
+NAME="${1:-$DEFAULT_NAME}"
+
+DEFAULT_QUANT_CFG="int8_sq"
+QUANT_CFG="${2:-$DEFAULT_QUANT_CFG}"
+
+
+# CHANGE THE FOLLOWING IF YOU MOUNT YOUR DATA AND CHECKPOINTS DIFFERENTLY IN THE CONTAINER.
+TP="1"
+INFERENCE_TP=${TP}
+DECODER_TYPE="llama"
+CHECKPOINT_LOAD_DIR="${NAME}"
+
+# int4_awq requires a weight block_size of 128, which limits the maximum inference TP.
+if [ "$QUANT_CFG" = "int4_awq" ]; then
+    INFERENCE_TP="2"
+fi
+
+additional_options=" \
+    --export-quant-cfg ${QUANT_CFG} \
+    --export-legacy-megatron \
+    --export-te-mcore-model \
+    --calib-batch-size 8 \
+    --decoder ${DECODER_TYPE} \
+    --export-dir /tmp/trtllm_ckpt \
+    --inference-tensor-parallel ${INFERENCE_TP} "
+
+# DO NOT CHANGE THE SETTING BELOW UNLESS YOU KNOW WHAT YOU ARE DOING!!!
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+options=" \
+    --disable-bias-linear \
+    --attention-backend unfused \
+    --swiglu \
+    --no-rope-fusion \
+    --untie-embeddings-and-output-weights \
+    --use-rotary-position-embeddings \
+    --normalization RMSNorm \
+    --rotary-percent 1.0 \
+    --hidden-dropout 0.0 \
+    --attention-dropout 0.0 \
+    --no-bias-gelu-fusion \
+    --no-bias-dropout-fusion \
+    --no-async-tensor-model-parallel-allreduce \
+    --tensor-model-parallel-size ${TP} \
+    --pipeline-model-parallel-size 1 \
+    --num-layers 32 \
+    --hidden-size 4096 \
+    --group-query-attention \
+    --num-query-groups 8 \
+    --ffn-hidden-size 14336 \
+    --num-attention-heads 32 \
+    --seq-length 8192 \
+    --max-position-embeddings 8192 \
+    --micro-batch-size 4 \
+    --make-vocab-size-divisible-by 128 \
+    --tokenizer-type HuggingFaceTokenizer \
+    --tokenizer-model meta-llama/Meta-Llama-3-8B \
+    --save-interval 1000000 \
+    --use-dist-ckpt \
+    --load ${CHECKPOINT_LOAD_DIR} \
+    --rotary-base 500000 \
+    --fp16"
+
+# Precompile CUDA extensions
+python -c "import modelopt.torch.quantization.extensions as ext; print(ext.cuda_ext); print(ext.cuda_ext_fp8)"
+
+# Set the torchrun launch configuration: one process per tensor-parallel rank
+launch_config="--nproc_per_node=${TP}"
+
+# Launch multi-process with torchrun
+torchrun ${launch_config} examples/export/ptq_and_trtllm_export/text_generation_ptq.py ${options} ${additional_options}
diff --git a/examples/export/ptq_and_trtllm_export/ptq_trtllm_minitron_8b.sh b/examples/export/ptq_and_trtllm_export/ptq_trtllm_minitron_8b.sh
new file mode 100644
index 0000000000..6e57972e30
--- /dev/null
+++ b/examples/export/ptq_and_trtllm_export/ptq_trtllm_minitron_8b.sh
@@ -0,0 +1,70 @@
+#!/bin/bash
+set -e
+
+DEFAULT_NAME="/checkpoints/nemotron3-8b_v0.3.0"
+NAME="${1:-$DEFAULT_NAME}"
+
+DEFAULT_QUANT_CFG="fp8"
+QUANT_CFG="${2:-$DEFAULT_QUANT_CFG}"
+
+# CHANGE THE FOLLOWING IF YOU MOUNT YOUR DATA AND CHECKPOINTS DIFFERENTLY IN THE CONTAINER.
+TP="8"
+INFERENCE_TP=${TP}
+DECODER_TYPE="gptnext"
+CHECKPOINT_LOAD_DIR="${NAME}/nemo"
+
+if [ "$QUANT_CFG" = "int4_awq" ]; then
+    INFERENCE_TP="1"
+fi
+
+additional_options=" \
+    --export-quant-cfg ${QUANT_CFG} \
+    --export-legacy-megatron \
+    --export-te-mcore-model \
+    --calib-batch-size 8 \
+    --decoder ${DECODER_TYPE} \
+    --export-dir /tmp/trtllm_ckpt \
+    --inference-tensor-parallel ${INFERENCE_TP} "
+
+# DO NOT CHANGE THE SETTING BELOW UNLESS YOU KNOW WHAT YOU ARE DOING!!!
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+options=" \
+    --apply-layernorm-1p \
+    --attention-backend unfused \
+    --untie-embeddings-and-output-weights \
+    --disable-bias-linear \
+    --no-rope-fusion \
+    --no-position-embedding \
+    --use-rotary-position-embeddings \
+    --rotary-percent 0.5 \
+    --squared-relu \
+    --attention-dropout 0.0 \
+    --hidden-dropout 0.0 \
+    --tensor-model-parallel-size ${TP} \
+    --pipeline-model-parallel-size 1 \
+    --num-layers 32 \
+    --hidden-size 4096 \
+    --ffn-hidden-size 16384 \
+    --group-query-attention \
+    --num-attention-heads 48 \
+    --kv-channels 128 \
+    --seq-length 4096 \
+    --num-query-groups 8 \
+    --max-position-embeddings 4096 \
+    --micro-batch-size 4 \
+    --tokenizer-type HuggingFaceTokenizer \
+    --tokenizer-model nvidia/Minitron-8B-Base \
+    --save-interval 1000000 \
+    --load ${CHECKPOINT_LOAD_DIR} \
+    --bf16 \
+    --use-dist-ckpt"
+
+# Precompile CUDA extensions
+python -c "import modelopt.torch.quantization.extensions as ext; print(ext.cuda_ext); print(ext.cuda_ext_fp8)"
+
+# Set the torchrun launch configuration: one process per tensor-parallel rank
+launch_config="--nproc_per_node=${TP}"
+
+# Launch multi-process with torchrun
+torchrun ${launch_config} examples/export/ptq_and_trtllm_export/text_generation_ptq.py ${options} ${additional_options}
diff --git a/examples/export/ptq_and_trtllm_export/ptq_trtllm_mistral_12b.sh b/examples/export/ptq_and_trtllm_export/ptq_trtllm_mistral_12b.sh
new file mode 100644
index 0000000000..8469945f08
--- /dev/null
+++ b/examples/export/ptq_and_trtllm_export/ptq_trtllm_mistral_12b.sh
@@ -0,0 +1,71 @@
+#!/bin/bash
+set -e
+
+DEFAULT_NAME="/checkpoints/Mistral-NeMo-12B-Base"
+NAME="${1:-$DEFAULT_NAME}"
+
+DEFAULT_QUANT_CFG="fp8"
+QUANT_CFG="${2:-$DEFAULT_QUANT_CFG}"
+
+# CHANGE THE FOLLOWING IF YOU MOUNT YOUR DATA AND CHECKPOINTS DIFFERENTLY IN THE CONTAINER.
+TP="8"
+INFERENCE_TP=${TP}
+DECODER_TYPE="llama"
+CHECKPOINT_LOAD_DIR="${NAME}"
+
+if [ "$QUANT_CFG" = "int4_awq" ]; then
+    INFERENCE_TP="1"
+fi
+
+additional_options=" \
+    --export-quant-cfg ${QUANT_CFG} \
+    --export-legacy-megatron \
+    --export-te-mcore-model \
+    --calib-batch-size 8 \
+    --decoder ${DECODER_TYPE} \
+    --export-dir /tmp/trtllm_ckpt \
+    --inference-tensor-parallel ${INFERENCE_TP} "
+
+# DO NOT CHANGE THE SETTING BELOW UNLESS YOU KNOW WHAT YOU ARE DOING!!!
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+options=" \
+    --untie-embeddings-and-output-weights \
+    --attention-backend unfused \
+    --disable-bias-linear \
+    --use-rotary-position-embeddings \
+    --rotary-percent 1.0 \
+    --attention-dropout 0.0 \
+    --hidden-dropout 0.0 \
+    --tensor-model-parallel-size ${TP} \
+    --pipeline-model-parallel-size 1 \
+    --num-layers 40 \
+    --hidden-size 5120 \
+    --ffn-hidden-size 14336 \
+    --num-attention-heads 32 \
+    --seq-length 8192 \
+    --kv-channels 128 \
+    --normalization RMSNorm \
+    --swiglu \
+    --num-query-groups 8 \
+    --group-query-attention \
+    --position-embedding-type rope \
+    --max-position-embeddings 8192 \
+    --micro-batch-size 1 \
+    --tokenizer-type HuggingFaceTokenizer \
+    --tiktoken-pattern v2 \
+    --tokenizer-model mistralai/Mistral-Nemo-Base-2407 \
+    --save-interval 1000000 \
+    --load ${CHECKPOINT_LOAD_DIR} \
+    --fp16 \
+    --rotary-base 1000000 \
+    --use-dist-ckpt"
+
+# Precompile CUDA extensions
+python -c "import modelopt.torch.quantization.extensions as ext; print(ext.cuda_ext); print(ext.cuda_ext_fp8)"
+
+# Set the torchrun launch configuration: one process per tensor-parallel rank
+launch_config="--nproc_per_node=${TP}"
+
+# Launch multi-process with torchrun
+torchrun ${launch_config} examples/export/ptq_and_trtllm_export/text_generation_ptq.py ${options} ${additional_options}
diff --git a/examples/export/ptq_and_trtllm_export/ptq_trtllm_mixtral_8x7b.sh b/examples/export/ptq_and_trtllm_export/ptq_trtllm_mixtral_8x7b.sh
new file mode 100644
index 0000000000..d2a4edee47
--- /dev/null
+++ b/examples/export/ptq_and_trtllm_export/ptq_trtllm_mixtral_8x7b.sh
@@ -0,0 +1,84 @@
+#!/bin/bash
+set -e
+
+DEFAULT_NAME="/checkpoints/Mistral-NeMo-12B-Base"
+NAME="${1:-$DEFAULT_NAME}"
+
+DEFAULT_QUANT_CFG="fp8"
+QUANT_CFG="${2:-$DEFAULT_QUANT_CFG}"
+
+# NOTE: UNFUSED ATTENTION MUST BE USED TO AVOID ADDITIONAL STATE_DICT KEY MISMATCH.
+export NVTE_FLASH_ATTN=0
+export NVTE_FUSED_ATTN=0
+export NVTE_UNFUSED_ATTN=1
+
+# CHANGE THE FOLLOWING IF YOU MOUNT YOUR DATA AND CHECKPOINTS DIFFERENTLY IN THE CONTAINER.
+TP="8"
+INFERENCE_TP=${TP}
+DECODER_TYPE="llama"
+CHECKPOINT_LOAD_DIR="${NAME}"
+
+if [ "$QUANT_CFG" = "int4_awq" ]; then
+    INFERENCE_TP="1"
+fi
+
+additional_options=" \
+    --export-quant-cfg ${QUANT_CFG} \
+    --export-legacy-megatron \
+    --export-te-mcore-model \
+    --calib-batch-size 8 \
+    --decoder ${DECODER_TYPE} \
+    --export-dir /tmp/trtllm_ckpt \
+    --inference-tensor-parallel ${INFERENCE_TP} "
+
+# DO NOT CHANGE THE SETTING BELOW UNLESS YOU KNOW WHAT YOU ARE DOING!!!
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+options=" \
+    --untie-embeddings-and-output-weights \
+    --no-masked-softmax-fusion \
+    --no-position-embedding \
+    --use-mcore-models \
+    --disable-bias-linear \
+    --rotary-percent 1.0 \
+    --attention-dropout 0.0 \
+    --hidden-dropout 0.0 \
+    --tensor-model-parallel-size ${TP} \
+    --pipeline-model-parallel-size 1 \
+    --num-layers 32 \
+    --hidden-size 4096 \
+    --ffn-hidden-size 14336 \
+    --num-attention-heads 32 \
+    --seq-length 4096 \
+    --kv-channels 128 \
+    --normalization RMSNorm \
+    --swiglu \
+    --num-query-groups 8 \
+    --num-experts 8 \
+    --moe-router-topk 2 \
+    --moe-aux-loss-coeff 1e-2 \
+    --moe-router-load-balancing-type aux_loss \
+    --group-query-attention \
+    --position-embedding-type rope \
+    --no-rope-fusion \
+    --max-position-embeddings 32768 \
+    --micro-batch-size 1 \
+    --tokenizer-type HuggingFaceTokenizer \
+    --tiktoken-pattern v2 \
+    --tokenizer-model mistralai/Mixtral-8x7B-Instruct-v0.1 \
+    --save-interval 1000000 \
+    --load ${CHECKPOINT_LOAD_DIR} \
+    --bf16 \
+    --rotary-base 1000000 \
+    --use-dist-ckpt"
+
+# Precompile CUDA extensions
+python -c "import modelopt.torch.quantization.extensions as ext; print(ext.cuda_ext); print(ext.cuda_ext_fp8)"
+
+# Set the torchrun launch configuration: one process per tensor-parallel rank
+launch_config="--nproc_per_node=${TP}"
+
+# Launch multi-process with torchrun
+torchrun ${launch_config} examples/export/ptq_and_trtllm_export/text_generation_ptq.py ${options} ${additional_options}
+
+
diff --git a/examples/export/ptq_and_trtllm_export/text_generation_ptq.py b/examples/export/ptq_and_trtllm_export/text_generation_ptq.py
new file mode 100644
index 0000000000..c915cec790
--- /dev/null
+++ b/examples/export/ptq_and_trtllm_export/text_generation_ptq.py
@@ -0,0 +1,222 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+"""Sample Generate GPT."""
+import functools
+import os
+import sys
+from pathlib import Path
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")))
+
+import modelopt.torch.quantization as mtq
+import torch
+from datasets import load_dataset
+from tqdm import tqdm
+
+# [ModelOpt]: changing the default model provider to the ModelOpt version
+from megatron.core import mpu
+from megatron.inference.arguments import add_modelopt_args
+from megatron.inference.checkpointing import load_modelopt_checkpoint
+from megatron.inference.gpt.model_provider import model_provider
+from megatron.inference.text_generation import generate_and_post_process
+from megatron.training import get_args, get_model, initialize_megatron
+from megatron.training.checkpointing import save_checkpoint
+from megatron.training.utils import print_rank_0, unwrap_model
+
+QUANT_CFG_CHOICES = {
+    "int8": mtq.INT8_DEFAULT_CFG,
+    "int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
+    "fp8": mtq.FP8_DEFAULT_CFG,
+    "int4_awq": mtq.INT4_AWQ_CFG,
+    "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
+    "int4": mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
+}
+
+
+def add_trtllm_ckpt_export_args(parser):
+    """Add additional arguments for TensorRT-LLM."""
+    group = parser.add_argument_group(title="trtllm")
+
+    group.add_argument(
+        "--export-dir", type=str, help="The output TensorRT-LLM checkpoint.",
+    )
+    group.add_argument(
+        "--decoder", type=str, choices=["gptnext", 'llama'], help="The decoder type of the model.",
+    )
+    group.add_argument(
+        "--inference-tensor-parallel",
+        type=int,
+        help="Tensor parallel for the inference time, can be different from the training config.",
+        default=1,
+    )
+
+
+def add_text_generate_ptq_args(parser):
+    """Add additional arguments for ModelOpt text generation PTQ."""
+    group = parser.add_argument_group(title='ModelOpt text generation ptq')
+    group.add_argument(
+        "--calib-dataset",
+        type=str,
+        default="cnn_dailymail",
+        help="Calibration datasets from HuggingFace datasets.",
+    )
+    group.add_argument(
+        "--calib-batch-size", type=int, default=4, help="Batch size to use for ptq calibration."
+    )
+    group.add_argument(
+        "--calib-size", type=int, default=512, help="Samples to use for ptq calibration."
+    )
+    parser.add_argument(
+        "--prompts",
+        type=str,
+        default=(
+            "Born in north-east France, Soyer trained as a|Born in California, Soyer trained as a"
+        ),
+        help="Input texts. Please use | to separate different batches.",
+    )
+    add_modelopt_args(parser)
+    add_trtllm_ckpt_export_args(parser)
+    return parser
+
+
+def get_calib_dataloader(
+    data="cnn_dailymail", batch_size=4, calib_size=512, max_sequence_length=512
+):
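+    """Yield batches of raw calibration text, truncated to max_sequence_length characters."""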
+    if data == "pileval":
+        dataset = load_dataset(
+            "json", data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst", split="train"
+        )
+        text_column = "text"
+    elif data == "wikitext":
+        dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
+        text_column = "text"
+    elif data == "cnn_dailymail":
+        dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
+        text_column = "article"
+
+    calib_size = max(min(len(dataset), calib_size), batch_size)
+    for i in range(calib_size // batch_size):
+        batch = dataset[i * batch_size : (i + 1) * batch_size][text_column]
+        for j in range(len(batch)):
+            batch[j] = batch[j][:max_sequence_length]
+        yield batch
+
+
+
+if __name__ == "__main__":
+    initialize_megatron(
+        extra_args_provider=add_text_generate_ptq_args,
+        args_defaults={
+            'tokenizer_type': 'GPT2BPETokenizer',
+            'no_load_rng': True,
+            'no_load_optim': True,
+        },
+    )
+
+    args = get_args()
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        print_rank_0("Interleaved pipeline schedule is not yet supported for text generation.")
+        sys.exit()
+
+    print_rank_0("WARNING: Forcing exit_on_missing_checkpoint to True for text generation.")
+    args.exit_on_missing_checkpoint = True
+    if getattr(args, 'moe_grouped_gemm', False):
+        print_rank_0("WARNING: Forcing moe_grouped_gemm to False for PTQ and export.")
+        args.moe_grouped_gemm = False
+
+    # Set up model and load checkpoint
+    # [ModelOpt]: make sure that output logits are allgathered.
+    text_generation_model_provider = functools.partial(model_provider, parallel_output=False)
+    model = get_model(text_generation_model_provider, wrap_with_ddp=False)
+
+    if args.load is not None:
+        load_modelopt_checkpoint(model, strict=not args.untie_embeddings_and_output_weights)
+        print_rank_0("Done loading checkpoint")
+
+    # Removing virtual pipeline parallel and other wrapper
+    assert len(model) == 1, "Above condition should have caught this"
+    unwrapped_model = unwrap_model(model)
+
+    all_prompts = args.prompts.split("|")
+
+    def custom_prompt_forward_loop_func(model):
+        for prompt in tqdm(all_prompts):
+            if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
+                (
+                    prompts_plus_generations,
+                    prompts_plus_generations_segments,
+                    logprobs,
+                    _,
+                ) = generate_and_post_process(
+                    model,
+                    prompts=[prompt],
+                    tokens_to_generate=128,
+                    return_output_log_probs=True,
+                    temperature=1.0,
+                )
+                print_rank_0(prompts_plus_generations)
+            else:
+                generate_and_post_process(model)
+
+    def hf_dataset_forward_loop_func(model):
+        dataloader = get_calib_dataloader(args.calib_dataset, args.calib_batch_size, args.calib_size)
+        for prompts in tqdm(dataloader, total=args.calib_size//args.calib_batch_size):
+            if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
+                (
+                    prompts_plus_generations,
+                    prompts_plus_generations_segments,
+                    logprobs,
+                    _,
+                ) = generate_and_post_process(
+                    model,
+                    prompts=prompts,
+                    tokens_to_generate=0,
+                    return_output_log_probs=False,
+                    temperature=1.0,
+                )
+            else:
+                generate_and_post_process(model)
+
+    ptq_forward_loop_func = custom_prompt_forward_loop_func
+    if args.calib_dataset is not None:
+        ptq_forward_loop_func = hf_dataset_forward_loop_func
+
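+    # The quantization step below inserts simulated quantizers into the model and calibrates them by
+    # running the selected forward loop (custom prompts or the HuggingFace calibration dataset).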
+    if args.export_quant_cfg in QUANT_CFG_CHOICES:
+        mtq_config = QUANT_CFG_CHOICES[args.export_quant_cfg]
+        if "*output_layer*" not in mtq_config["quant_cfg"]:
+            mtq_config["quant_cfg"]["*output_layer*"] = {"enable": False}
+        if "awq" in args.export_quant_cfg:
+            weight_quantizer = mtq_config["quant_cfg"]["*weight_quantizer"]  # type: ignore
+            if isinstance(weight_quantizer, list):
+                weight_quantizer = weight_quantizer[0]
+            weight_quantizer["block_sizes"][-1] = 128
+        print_rank_0("Quantizing the model...")
+        mtq.quantize(unwrapped_model[0], mtq_config, ptq_forward_loop_func)
+
+    custom_prompt_forward_loop_func(model[0])
+
+    if args.save is not None and args.export_quant_cfg in QUANT_CFG_CHOICES:
+        save_checkpoint(1, unwrapped_model, None, None, 0)
+
+    print_rank_0(f"Fake Quantized Model:\n {unwrapped_model[0]}")
+
+    if args.export_dir:
+        assert args.decoder in ["gptnext", "llama"], f"Decoder type {args.decoder} not supported."
+        Path(args.export_dir).mkdir(parents=True, exist_ok=True)
+        print_rank_0("Exporting TensorRT-LLM checkpoints.")
+
+        from modelopt.torch.export import export_tensorrt_llm_checkpoint
+
+        # In TRT LLM, squared relu activation does not support bf16. So we use fp16 by default.
+        export_tensorrt_llm_checkpoint(
+            unwrapped_model[0],
+            args.decoder,
+            torch.bfloat16 if args.bf16 else torch.float16,
+            export_dir=args.export_dir,
+            inference_tensor_parallel=args.inference_tensor_parallel,
+            inference_pipeline_parallel=1,
+            use_nfs_workspace=True,
+        )
+
+        print_rank_0(f"TensorRT-LLM checkpoints saved to {args.export_dir}")
+        torch.distributed.barrier()
diff --git a/examples/export/ptq_and_trtllm_export/trtllm_text_generation.py b/examples/export/ptq_and_trtllm_export/trtllm_text_generation.py
new file mode 100644
index 0000000000..ab8aa25a96
--- /dev/null
+++ b/examples/export/ptq_and_trtllm_export/trtllm_text_generation.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+"""An example script to run the tensorrt_llm engine."""
+
+import argparse
+from pathlib import Path
+import subprocess
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from modelopt.deploy.llm import LLM
+from tensorrt_llm.models import PretrainedConfig
+from transformers import AutoTokenizer, T5Tokenizer
+import tensorrt_llm
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--tokenizer", type=str, default="")
+    parser.add_argument("--engine-dir", type=str, default="/tmp/trtllm_engine")
+    parser.add_argument(
+        "--input-texts",
+        type=str,
+        default=(
+            "Born in north-east France, Soyer trained as a|Born in California, Soyer trained as a"
+        ),
+        help="Input texts. Please use | to separate different batches.",
+    )
+    return parser.parse_args()
+
+
+def run(args):
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
+    except Exception as e:
+        raise Exception(f"Failed to load tokenizer: {e}")
+
+    print(tokenizer, tokenizer.vocab_size)
+
+    input_texts = args.input_texts.split("|")
+    assert input_texts, "input_text not specified"
+    print(input_texts)
+
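+    # torch.cuda.mem_get_info() returns (free, total) device memory in bytes.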
+    free_memory_before = torch.cuda.mem_get_info()
+
+    # This is a ModelOpt wrapper on top of tensorrt_llm.hlapi.llm.LLM
+    llm_engine = LLM(args.engine_dir, tokenizer)
+
+    torch.cuda.cudart().cudaProfilerStart()
+    # outputs = llm_engine.generate_text(input_texts, args.max_output_len, args.max_beam_width)
+    outputs = llm_engine.generate(input_texts)
+    torch.cuda.cudart().cudaProfilerStop()
+
+    free_memory_after = torch.cuda.mem_get_info()
+    print(
+        f"Used GPU memory: {(free_memory_before[0] - free_memory_after[0]) / 1024 / 1024 / 1024} GB"
+    )
+    print(outputs)
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    run(args)
diff --git a/examples/export/trtllm_export/README.md b/examples/export/trtllm_export/README.md
new file mode 100644
index 0000000000..52cad78583
--- /dev/null
+++ b/examples/export/trtllm_export/README.md
@@ -0,0 +1,161 @@
+# Megatron Core To TRTLLM Export Documentation
+This guide walks you through using the Megatron Core export API to export models to the TensorRT-LLM (trtllm) format.
+
+### Contents
+- [Megatron Core To TRTLLM Export Documentation](#megatron-core-to-trtllm-export-documentation)
+- [Contents](#contents)
+  - [1. Quick Start](#1-quick-start)
+    - [1.1 Understanding The Code](#11-understanding-the-code)
+    - [1.2 Running The Code](#12-running-the-code)
+  - [2. GPU Export](#2-gpu-export)
+  - [3. Future work](#3-future-work)
+
+#### 1. Quick Start
+This section walks you through converting an MCore GPT model to trtllm format in single-device mode. The full example can be found at [gpt_single_device_cpu_export.py](./single_device_export/gpt_single_device_cpu_export.py).
+
+NOTE: For faster performance, if your entire model fits into GPU memory, transfer the model state dict to the GPU before calling the `get_trtllm_pretrained_config_and_model_weights` function (see the sketch below).
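+
+The sketch below illustrates that note under stated assumptions: `gpt_model` and `trtllm_helper` are the objects built
+in the walkthrough that follows, and the whole state dict is assumed to fit on a single GPU. It is not part of the
+provided example scripts.
+
+```python
+# Hedged sketch: pre-transfer the (filtered) state dict to the GPU before conversion.
+# Assumes `gpt_model` and `trtllm_helper` were created as in steps 2 and 3 below.
+import torch
+from megatron.core.export.data_type import DataType
+from megatron.core.export.export_config import ExportConfig
+
+model_state_dict = {
+    k: (v.cuda() if torch.is_tensor(v) else v)   # move tensors to the GPU up front
+    for k, v in gpt_model.state_dict().items()
+    if v is not None                             # _extra_state entries can be None and are skipped
+}
+
+weight_list, config_list = trtllm_helper.get_trtllm_pretrained_config_and_model_weights(
+    model_state_dict=model_state_dict,
+    dtype=DataType.bfloat16,
+    export_config=ExportConfig(inference_tp_size=2),
+)
+```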
+
+
+ +##### 1.1 Understanding The Code +***STEP 1 - We initialize model parallel and other default arguments*** +We initalize tp and pp to 1 so that we can get the full model state dict on cpu +```python + initialize_distributed(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) +``` + +***STEP 2 - We load the model using the model_provider_function*** +NOTE: We create a simple gpt model + +```python + transformer_config = TransformerConfig( + num_layers=2, + hidden_size=64, # Needs to be atleast 32 times num_attn_heads + num_attention_heads=2, + use_cpu_initialization=True, + pipeline_dtype=torch.float32, + ) + + gpt_model = GPTModel( + config=transformer_config, + transformer_layer_spec=get_gpt_layer_local_spec(), + vocab_size=100, + max_sequence_length=_SEQUENCE_LENGTH, + ) + + # Optionally you can also load a model using this code + # sharded_state_dict=gpt_model.sharded_state_dict(prefix='') + # checkpoint = dist_checkpointing.load(sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path) + # gpt_model.load_state_dict(checkpoint) + +``` + +***STEP 3 - Instantiate the TRTLLM Helper*** +We instantiate the [TRTLLM Helper](../../../megatron/core/export/trtllm/trtllm_helper.py) For the GPT model we instantiate trtllm_helper as shown below. +```python + if hasattr(gpt_model, "rotary_pos_emb"): + seq_len_interpolation_factor = gpt_model.rotary_pos_emb.seq_len_interpolation_factor + + trtllm_helper = TRTLLMHelper( + transformer_config=gpt_model.config, + model_type=ModelType.gpt, + position_embedding_type = gpt_model.position_embedding_type, + max_position_embeddings = gpt_model.max_position_embeddings, + rotary_percentage = gpt_model.rotary_percent, + rotary_base = gpt_model.rotary_base, + moe_tp_mode = 2, + multi_query_mode = False, + activation = "gelu", + seq_len_interpolation_factor = seq_len_interpolation_factor, + share_embeddings_and_output_weights=gpt_model.share_embeddings_and_output_weights + ) +``` + +***STEP 4 - Get the TRTLLM Weights and configs*** +To convert model weights to trtllm weights and configs, we use the [single_device_converter](../../../megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py). We pass as inputs the model state dict, and export config. In this example we use inference tp size as 2 for the export. + +```python + model_state_dict={} + for key , val in gpt_model.state_dict().items(): + # val is non for _extra_state layers . We filter it out + if val is not None: + model_state_dict[key] = val + + export_config = ExportConfig(inference_tp_size = 2) + weight_list, config_list = trtllm_helper.get_trtllm_pretrained_config_and_model_weights( + model_state_dict= model_state_dict, + dtype = DataType.bfloat16, + export_config=export_config + ) +``` + +***STEP 5 - Build the TRTLLM Engine*** +Following code is used to build the TRTLLM Engine. 
+ +```python + for trtllm_model_weights, trtllm_model_config in zip(weight_list, config_list): + trtllm_helper.build_and_save_engine( + max_input_len=256, + max_output_len=256, + max_batch_size=8, + engine_dir='/opt/megatron-lm/engine', + trtllm_model_weights=trtllm_model_weights, + trtllm_model_config=trtllm_model_config, + lora_ckpt_list=None, + use_lora_plugin=None, + max_lora_rank=64, + lora_target_modules=None, + max_prompt_embedding_table_size=0, + paged_kv_cache=True, + remove_input_padding=True, + paged_context_fmha=False, + use_refit=False, + max_num_tokens=None, + max_seq_len=512, + opt_num_tokens=None, + max_beam_width=1, + tokens_per_block=128, + multiple_profiles=False, + gpt_attention_plugin="auto", + gemm_plugin="auto", + ) +``` +
+
+##### 1.2 Running The Code
+An example run script is shown below.
+
+```
+# In a workstation
+MLM_PATH=/path/to/megatron-lm
+CONTAINER_IMAGE=gitlab-master.nvidia.com:5005/dl/joc/nemo-ci/trtllm_0.12/train:pipe.17669124-x86
+
+docker run -it --gpus=all --ipc=host -v $MLM_PATH/:/opt/megatron-lm $CONTAINER_IMAGE bash
+
+# Inside the container run the following.
+
+cd /opt/megatron-lm/
+
+CUDA_VISIBLE_DEVICES=0 torchrun --nproc-per-node 1 examples/export/trtllm_export/single_device_export/gpt_single_device_cpu_export.py
+```
+
+
+#### 2. GPU Export
+You can use [gpt_distributed_gpu_export.py](./distributed_export/gpt_distributed_gpu_export.py) to run a more optimized, on-device distributed version of the trtllm export. Internally this uses the [distributed_converter](../../../megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py) to convert model weights on device.
+In the single-device version you collect all the model weights on CPU/GPU, convert them to trtllm format, and then store the engine back on disk. In the GPU version you load each individual state dict on the GPUs, convert it on the device itself, and store the engine on disk.
+
+To run the GPU version:
+
+```
+CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc-per-node 2 examples/export/trtllm_export/distributed_export/gpt_distributed_gpu_export.py
+```
+
+ +#### 3. Future work +The following are planned for the future releases . +* Pipeline parallellism for export (Work in progress) +* GPU Export for more models (Work in progress for some models) +* Refit functionality +* VLLM Support \ No newline at end of file diff --git a/examples/export/trtllm_export/distributed_export/gpt_distributed_gpu_export.py b/examples/export/trtllm_export/distributed_export/gpt_distributed_gpu_export.py new file mode 100644 index 0000000000..57d44f9f62 --- /dev/null +++ b/examples/export/trtllm_export/distributed_export/gpt_distributed_gpu_export.py @@ -0,0 +1,117 @@ +import os +import torch +from megatron.core import parallel_state +from megatron.core import dist_checkpointing +from megatron.core.export.model_type import ModelType +from megatron.core.export.data_type import DataType +from megatron.core.export.trtllm.trtllm_helper import TRTLLMHelper +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec + + +_SEQUENCE_LENGTH = 64 +_VOCAB_SIZE = 256 + +def initialize_distributed(tensor_model_parallel_size=1, pipeline_model_parallel_size=1): + parallel_state.destroy_model_parallel() + + # Torch setup for distributed training + rank = int(os.environ['LOCAL_RANK']) + world_size = torch.cuda.device_count() + torch.cuda.set_device(rank) + torch.distributed.init_process_group(world_size=world_size, rank=rank) + + # Megatron core distributed training initialization + parallel_state.initialize_model_parallel(tensor_model_parallel_size = tensor_model_parallel_size, pipeline_model_parallel_size=pipeline_model_parallel_size) + +def model_provider(): + """Build the model.""" + + transformer_config = TransformerConfig( + num_layers=2, + hidden_size=64, + num_attention_heads=2, + use_cpu_initialization=True, + pipeline_dtype=torch.float32 + ) + + gpt_model = GPTModel( + config=transformer_config, + transformer_layer_spec=get_gpt_layer_local_spec(), + vocab_size=_VOCAB_SIZE, + max_sequence_length=_SEQUENCE_LENGTH, + ) + + return gpt_model + +def load_distributed_checkpoint(checkpoint_path, gpt_model): + sharded_state_dict=gpt_model.sharded_state_dict(prefix='') + checkpoint = dist_checkpointing.load(sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path) + gpt_model.load_state_dict(checkpoint) + return gpt_model + +if __name__ == "__main__": + initialize_distributed(tensor_model_parallel_size=2, pipeline_model_parallel_size=1) + model_parallel_cuda_manual_seed(123) + + gpt_model = model_provider() + device = torch.device("cuda") + gpt_model.to(device) + + # Optionally you can also load a gpt model from ckpt_path using this code below + # gpt_model = load_distributed_checkpoint(gpt_model=gpt_model, checkpoint_path=ckpt_path) + + seq_len_interpolation_factor = None + if hasattr(gpt_model, "rotary_pos_emb"): + seq_len_interpolation_factor = gpt_model.rotary_pos_emb.seq_len_interpolation_factor + + trtllm_helper = TRTLLMHelper( + transformer_config=gpt_model.config, + model_type=ModelType.gpt, + position_embedding_type = gpt_model.position_embedding_type, + max_position_embeddings = gpt_model.max_position_embeddings, + rotary_percentage = gpt_model.rotary_percent, + rotary_base = gpt_model.rotary_base, + moe_tp_mode = 2, + multi_query_mode = False, + activation = "gelu", + seq_len_interpolation_factor = 
seq_len_interpolation_factor, + share_embeddings_and_output_weights=gpt_model.share_embeddings_and_output_weights + ) + + + trtllm_model_weights, trtllm_model_config = trtllm_helper.get_trtllm_pretrained_config_and_model_weights( + model_state_dict= gpt_model.state_dict(), + dtype = DataType.bfloat16, + on_device_distributed_conversion=True, + vocab_size=_VOCAB_SIZE, + gpus_per_node=2, + ) + + trtllm_helper.build_and_save_engine( + max_input_len=256, + max_output_len=256, + max_batch_size=8, + engine_dir='/opt/megatron-lm/engine', + trtllm_model_weights=trtllm_model_weights[0], + trtllm_model_config=trtllm_model_config[0], + lora_ckpt_list=None, + use_lora_plugin=None, + max_lora_rank=64, + lora_target_modules=None, + max_prompt_embedding_table_size=0, + paged_kv_cache=True, + remove_input_padding=True, + paged_context_fmha=False, + use_refit=False, + max_num_tokens=None, + max_seq_len=512, + opt_num_tokens=None, + max_beam_width=1, + tokens_per_block=128, + multiple_profiles=False, + gpt_attention_plugin="auto", + gemm_plugin="auto", + ) diff --git a/examples/export/trtllm_export/single_device_export/gpt_single_device_cpu_export.py b/examples/export/trtllm_export/single_device_export/gpt_single_device_cpu_export.py new file mode 100644 index 0000000000..587e7cfdd3 --- /dev/null +++ b/examples/export/trtllm_export/single_device_export/gpt_single_device_cpu_export.py @@ -0,0 +1,118 @@ +import os +import torch +from megatron.core import parallel_state +from megatron.core import dist_checkpointing +from megatron.core.export.model_type import ModelType +from megatron.core.export.data_type import DataType +from megatron.core.export.export_config import ExportConfig +from megatron.core.export.trtllm.trtllm_helper import TRTLLMHelper +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec + + +_SEQUENCE_LENGTH = 64 + + +def initialize_distributed(tensor_model_parallel_size=1, pipeline_model_parallel_size=1): + parallel_state.destroy_model_parallel() + + # Torch setup for distributed training + rank = int(os.environ['LOCAL_RANK']) + world_size = torch.cuda.device_count() + torch.cuda.set_device(rank) + torch.distributed.init_process_group(world_size=world_size, rank=rank) + + # Megatron core distributed training initialization + parallel_state.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size) + +def model_provider(): + """Build the model.""" + + transformer_config = TransformerConfig( + num_layers=2, + hidden_size=64, # Needs to be atleast 32 times num_attn_heads + num_attention_heads=2, + use_cpu_initialization=True, + pipeline_dtype=torch.float32, + ) + + gpt_model = GPTModel( + config=transformer_config, + transformer_layer_spec=get_gpt_layer_local_spec(), + vocab_size=100, + max_sequence_length=_SEQUENCE_LENGTH, + ) + + return gpt_model + +def load_distributed_checkpoint(checkpoint_path, gpt_model): + sharded_state_dict=gpt_model.sharded_state_dict(prefix='') + checkpoint = dist_checkpointing.load(sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path) + gpt_model.load_state_dict(checkpoint) + return gpt_model + +if __name__ == "__main__": + # Need to use TP1 PP1 for export on single device + initialize_distributed(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) + 
model_parallel_cuda_manual_seed(123) + + gpt_model = model_provider() + + # Optionally you can also load a gpt model from ckpt_path using this code below + # gpt_model = load_distributed_checkpoint(gpt_model=gpt_model, checkpoint_path=ckpt_path) + + seq_len_interpolation_factor = None + if hasattr(gpt_model, "rotary_pos_emb"): + seq_len_interpolation_factor = gpt_model.rotary_pos_emb.seq_len_interpolation_factor + + trtllm_helper = TRTLLMHelper( + transformer_config=gpt_model.config, + model_type=ModelType.gpt, + position_embedding_type = gpt_model.position_embedding_type, + max_position_embeddings = gpt_model.max_position_embeddings, + rotary_percentage = gpt_model.rotary_percent, + rotary_base = gpt_model.rotary_base, + moe_tp_mode = 2, + multi_query_mode = False, + activation = "gelu", + seq_len_interpolation_factor = seq_len_interpolation_factor, + share_embeddings_and_output_weights=gpt_model.share_embeddings_and_output_weights + ) + + + export_config = ExportConfig(inference_tp_size = 2) + # NOTE : For faster performance, if your entire model will fit in gpu memory, transfer model state dict to GPU and then call this api + weight_list, config_list = trtllm_helper.get_trtllm_pretrained_config_and_model_weights( + model_state_dict= gpt_model.state_dict(), + dtype = DataType.bfloat16, + export_config=export_config + ) + + for trtllm_model_weights, trtllm_model_config in zip(weight_list, config_list): + trtllm_helper.build_and_save_engine( + max_input_len=256, + max_output_len=256, + max_batch_size=8, + engine_dir='/opt/megatron-lm/engine', + trtllm_model_weights=trtllm_model_weights, + trtllm_model_config=trtllm_model_config, + lora_ckpt_list=None, + use_lora_plugin=None, + max_lora_rank=64, + lora_target_modules=None, + max_prompt_embedding_table_size=0, + paged_kv_cache=True, + remove_input_padding=True, + paged_context_fmha=False, + use_refit=False, + max_num_tokens=None, + max_seq_len=512, + opt_num_tokens=None, + max_beam_width=1, + tokens_per_block=128, + multiple_profiles=False, + gpt_attention_plugin="auto", + gemm_plugin="auto", + ) \ No newline at end of file diff --git a/examples/finetune_mnli_distributed.sh b/examples/finetune_mnli_distributed.sh deleted file mode 100755 index a3f9accbcc..0000000000 --- a/examples/finetune_mnli_distributed.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash - -WORLD_SIZE=8 - -DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \ - --nnodes 1 \ - --node_rank 0 \ - --master_addr localhost \ - --master_port 6000" - -TRAIN_DATA="data/glue_data/MNLI/train.tsv" -VALID_DATA="data/glue_data/MNLI/dev_matched.tsv \ - data/glue_data/MNLI/dev_mismatched.tsv" -PRETRAINED_CHECKPOINT=checkpoints/bert_345m -VOCAB_FILE=bert-vocab.txt -CHECKPOINT_PATH=checkpoints/bert_345m_mnli - -python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \ - --task MNLI \ - --seed 1234 \ - --train-data $TRAIN_DATA \ - --valid-data $VALID_DATA \ - --tokenizer-type BertWordPieceLowerCase \ - --vocab-file $VOCAB_FILE \ - --epochs 5 \ - --pretrained-checkpoint $PRETRAINED_CHECKPOINT \ - --tensor-model-parallel-size 1 \ - --num-layers 24 \ - --hidden-size 1024 \ - --num-attention-heads 16 \ - --micro-batch-size 8 \ - --lr 5.0e-5 \ - --lr-decay-style linear \ - --lr-warmup-fraction 0.065 \ - --seq-length 512 \ - --max-position-embeddings 512 \ - --save-interval 500000 \ - --save $CHECKPOINT_PATH \ - --log-interval 10 \ - --eval-interval 100 \ - --eval-iters 50 \ - --weight-decay 1.0e-1 \ - --fp16 diff --git a/examples/finetune_race_distributed.sh 
b/examples/finetune_race_distributed.sh deleted file mode 100755 index 3d92253388..0000000000 --- a/examples/finetune_race_distributed.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/bin/bash - -WORLD_SIZE=8 - -DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \ - --nnodes 1 \ - --node_rank 0 \ - --master_addr localhost \ - --master_port 6000" - -TRAIN_DATA="data/RACE/train/middle" -VALID_DATA="data/RACE/dev/middle \ - data/RACE/dev/high" -VOCAB_FILE=bert-vocab.txt -PRETRAINED_CHECKPOINT=checkpoints/bert_345m -CHECKPOINT_PATH=checkpoints/bert_345m_race - -python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \ - --task RACE \ - --seed 1234 \ - --train-data $TRAIN_DATA \ - --valid-data $VALID_DATA \ - --tokenizer-type BertWordPieceLowerCase \ - --vocab-file $VOCAB_FILE \ - --epochs 3 \ - --pretrained-checkpoint $PRETRAINED_CHECKPOINT \ - --tensor-model-parallel-size 1 \ - --num-layers 24 \ - --hidden-size 1024 \ - --num-attention-heads 16 \ - --micro-batch-size 4 \ - --lr 1.0e-5 \ - --lr-decay-style linear \ - --lr-warmup-fraction 0.06 \ - --seq-length 512 \ - --max-position-embeddings 512 \ - --save-interval 100000 \ - --save $CHECKPOINT_PATH \ - --log-interval 10 \ - --eval-interval 100 \ - --eval-iters 50 \ - --weight-decay 1.0e-1 \ - --clip-grad 1.0 \ - --hidden-dropout 0.1 \ - --attention-dropout 0.1 \ - --fp16 diff --git a/examples/finetune_retriever_distributed.sh b/examples/finetune_retriever_distributed.sh deleted file mode 100755 index 535a2e053d..0000000000 --- a/examples/finetune_retriever_distributed.sh +++ /dev/null @@ -1,56 +0,0 @@ -#!/bin/bash - -# Finetune a BERT or pretrained ICT model using Google natural question data -# Datasets can be downloaded from the following link: -# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py - -WORLD_SIZE=8 - -DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \ - --nnodes 1 \ - --node_rank 0 \ - --master_addr localhost \ - --master_port 6000" - -CHECKPOINT_PATH= - -# Load either of the below -BERT_LOAD_PATH= -PRETRAINED_CHECKPOINT= - -python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \ - --task RET-FINETUNE-NQ \ - --train-with-neg \ - --train-hard-neg 1 \ - --pretrained-checkpoint ${PRETRAINED_CHECKPOINT} \ - --num-layers 12 \ - --hidden-size 768 \ - --num-attention-heads 12 \ - --tensor-model-parallel-size 1 \ - --tokenizer-type BertWordPieceLowerCase \ - --train-data nq-train.json \ - --valid-data nq-dev.json \ - --save ${CHECKPOINT_PATH} \ - --load ${CHECKPOINT_PATH} \ - --vocab-file bert-vocab.txt \ - --bert-load ${BERT_LOAD_PATH} \ - --save-interval 5000 \ - --log-interval 10 \ - --eval-interval 20000 \ - --eval-iters 100 \ - --indexer-log-interval 1000 \ - --faiss-use-gpu \ - --DDP-impl torch \ - --fp16 \ - --retriever-report-topk-accuracies 1 5 10 20 100 \ - --seq-length 512 \ - --retriever-seq-length 256 \ - --max-position-embeddings 512 \ - --retriever-score-scaling \ - --epochs 80 \ - --micro-batch-size 8 \ - --eval-micro-batch-size 16 \ - --indexer-batch-size 128 \ - --lr 2e-5 \ - --lr-warmup-fraction 0.01 \ - --weight-decay 1e-1 diff --git a/examples/gpt3/README.md b/examples/gpt3/README.md index fec51e1fea..8d6f267416 100644 --- a/examples/gpt3/README.md +++ b/examples/gpt3/README.md @@ -10,7 +10,7 @@ To run the model using a docker container run it as follows ``` -PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:23.09-py3 +PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.01-py3 CHECKPOINT_PATH="" # TENSORBOARD_LOGS_PATH=""# VOCAB_FILE="" #/gpt2-vocab.json @@ -23,8 +23,8 @@ docker run \ --workdir 
/workspace/megatron-lm \ -v /path/to/data:/path/to/data \ -v /path/to/megatron-lm:/workspace/megatron-lm \ - megatron-lm nvcr.io/nvidia/pytorch:23.04-py3 \ - bash /examples/gpt3/train_gpt3_175b_distributed.sh $CHECKPOINT_PATH $TENSORBOARD_LOGS_PATH $VOCAB_FILE $MERGE_FILE $DATA_PATH " + megatron-lm nvcr.io/nvidia/pytorch:24.01-py3 \ + bash examples/gpt3/train_gpt3_175b_distributed.sh $CHECKPOINT_PATH $TENSORBOARD_LOGS_PATH $VOCAB_FILE $MERGE_FILE $DATA_PATH " ``` NOTE: Depending on the environment you are running it the above command might like slightly different. @@ -34,7 +34,7 @@ NOTE: Depending on the environment you are running it the above command might li The example in this folder shows you how to run 175B model. There are other configs you could run as well -### 345M +### 345M ``` --num-layers 12 \ --hidden-size 512 \ @@ -45,7 +45,7 @@ The example in this folder shows you how to run 175B model. There are other conf ``` -### 857M +### 857M ``` --num-layers 24 \ --hidden-size 1024 \ diff --git a/examples/gpt3/gpt_config.yaml b/examples/gpt3/gpt_config.yaml new file mode 100644 index 0000000000..8ef4f41e96 --- /dev/null +++ b/examples/gpt3/gpt_config.yaml @@ -0,0 +1,301 @@ +# WARNING: Yaml configs is currently an experimental feature +language_model: + # model architecture + num_layers: 24 + hidden_size: 1024 + num_attention_heads: 16 + num_query_groups: null + + ffn_hidden_size: null + kv_channels: null + hidden_dropout: 0.0 + attention_dropout: 0.0 + fp32_residual_connection: False + + apply_residual_connection_post_layernorm: False + layernorm_epsilon: 1.e-5 + layernorm_zero_centered_gamma: True + add_bias_linear: False + bias_activation_fusion: False + add_qkv_bias: False + gated_linear_unit: False + activation_func: swiglu + num_moe_experts: null + rotary_interleaved: False + window_size: null + + # initialization + init_method: null + init_method_std: 0.02 + output_layer_init_method: null + + # mixed-precision + apply_query_key_layer_scaling: False + attention_softmax_in_fp32: False + + # fusion + bias_swiglu_fusion: True + masked_softmax_fusion: True + persist_layer_norm: False + memory_efficient_layer_norm: False + bias_dropout_fusion: True + apply_rope_fusion: True + + # activation recomputation + recompute_granularity: null + recompute_method: null + recompute_num_layers: null + distribute_saved_activations: null + + # fp8 related + fp8: null + fp8_margin: 0 + fp8_interval: 1 + fp8_amax_history_len: 1 + fp8_amax_compute_algo: "most_recent" + fp8_wgrad: True + + # miscellaneous + clone_scatter_output_in_embedding: True + + normalization: "LayerNorm" # alt value supported by TE: "RMSNorm" + + # MoE related + moe_router_load_balancing_type: "aux_loss" + moe_router_topk: 2 + moe_router_group_topk: null + moe_router_num_groups: null + moe_grouped_gemm: False + moe_aux_loss_coeff: 0 # 1e-2 would be a good start value for load balance loss. 
+ moe_z_loss_coeff: null # 1e-3 would be a good start value for z-loss + moe_input_jitter_eps: null + moe_token_dropping: False + +model_parallel: + # Model parallelism + tensor_model_parallel_size: 1 + context_parallel_size: 1 + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + sequence_parallel: True + expert_model_parallel_size: 1 + + # Initialization + perform_initialization: True + use_cpu_initialization: null + + # Training + fp16: False + bf16: True + params_dtype: null # Set from above arguments for core + timers: null + + # Optimizations + gradient_accumulation_fusion: True + async_tensor_model_parallel_allreduce: True + tp_comm_overlap: False + + # Debug Options + tp_comm_split_ag: True + tp_comm_atomic_ag: True + tp_comm_split_rs: True + tp_comm_atomic_rs: True + tp_comm_bulk_wgrad: True + tp_comm_bulk_dgrad: True + + # Parallelism + finalize_model_grads_func: null + + # Pipeline Parallel + pipeline_dtype: null + grad_scale_func: null + enable_autocast: False + autocast_dtype: null + variable_seq_lengths: False + num_microbatches_with_partial_activation_checkpoints: null + overlap_p2p_comm: False + batch_p2p_comm: True + batch_p2p_sync: True + use_ring_exchange_p2p: False + deallocate_pipeline_outputs: False + no_sync_func: null + grad_sync_func: null + param_sync_func: null + pipeline_model_parallel_split_rank: null + + # CPU Offloading + cpu_offloading: False + cpu_offloading_num_layers: 0 + _cpu_offloading_context: null + cpu_offloading_weights: False + cpu_offloading_activations: True + + # Timing + barrier_with_L1_time: True + +# training: +use_legacy_models: False +spec: null +micro_batch_size: 2 +global_batch_size: 128 +rampup_batch_size: [32, 32, 65324160] +check_for_nan_in_loss_and_grad: True +num_layers_per_virtual_pipeline_stage: null + +encoder_num_layers: null +decoder_num_layers: null +rotary_seq_len_interpolation_factor: null +add_position_embedding: False +make_vocab_size_divisible_by: 128 +group_query_attention: False + + +exit_signal_handler: False +exit_duration_in_mins: null +exit_interval: null + +untie_embeddings_and_output_weights: True +position_embedding_type: rope +rotary_percent: 0.5 +openai_gelu: False +squared_relu: False +swiglu: True +onnx_safe: null +bert_binary_head: True +max_position_embeddings: 4096 + +transformer_impl: local +use_flash_attn: False +seed: 1234 +data_parallel_random_init: False + +# Optimizer +optimizer: adam +lr: 2.5e-4 +lr_decay_style: cosine +lr_decay_iters: null +lr_decay_samples: 255126953 +lr_warmup_fraction: null +lr_warmup_iters: 0 +lr_warmup_samples: 81381 +lr_warmup_init: 0.0 +min_lr: 2.5e-5 +weight_decay: 0.1 +start_weight_decay: null +end_weight_decay: null +weight_decay_incr_style: constant +clip_grad: 1.0 +adam_beta1: 0.9 +adam_beta2: 0.95 +adam_eps: 1.e-08 +sgd_momentum: 0.9 +override_opt_param_scheduler: False +use_checkpoint_opt_param_scheduler: False + +# checkpointing arguments +save: null +save_interval: 20000 +no_save_optim: null +no_save_rng: null +load: null +no_load_optim: null +no_load_rng: null +finetune: False +use_checkpoint_args: False +exit_on_missing_checkpoint: False + +# loss arguments +loss_scale: null +initial_loss_scale: 4294967296 +min_loss_scale: 1.0 +loss_scale_window: 1000 +hysteresis: 2 +accumulate_allreduce_grads_in_fp32: False +fp16_lm_cross_entropy: False + +# distributed arguments +distributed_backend: nccl +distributed_timeout_minutes: 10 +overlap_grad_reduce: False +align_grad_reduce: True +overlap_param_gather: False +align_param_gather: False 
+scatter_gather_tensors_in_pipeline: True +local_rank: null +lazy_mpu_init: null +empty_unused_memory_level: 0 +standalone_embedding_stage: False +use_distributed_optimizer: False +nccl_communicator_config_path: null + +train_iters: null +eval_iters: 32 +eval_interval: 2000 +skip_train: False + +adlr_autoresume: False +adlr_autoresume_interval: 1000 + +# garbage collection +manual_gc: False +manual_gc_interval: 0 +manual_gc_eval: True + +tp_comm_overlap_cfg: null + +#data +data_path: null +split: '99,1,0' +train_data_path: null +valid_data_path: null +test_data_path: null +data_cache_path: null +mock_data: False +vocab_size: null +vocab_file: null +merge_file: null +vocab_extra_ids: 0 +seq_length: 4096 +encoder_seq_length: null +decoder_seq_length: null +retriever_seq_length: 256 +sample_rate: 1.0 +mask_prob: 0.15 +short_seq_prob: 0.1 +num_workers: 2 +tokenizer_type: GPTSentencePieceTokenizer +tokenizer_model: null +reset_position_ids: False +reset_attention_mask: False +eod_mask_loss: False +train_samples: 268554688 +dataloader_type: null + +#profile: +profile: False +profile_ranks: [0] +profile_step_end: 12 +profile_step_start: 10 + +#logging: +log_params_norm: True +log_num_zeros_in_grad: True +log_throughput: False +log_progress: False +timing_log_level: 0 +timing_log_option: minmax +tensorboard_log_interval: 1 +tensorboard_queue_size: 1000 +log_timers_to_tensorboard: False +log_validation_ppl_to_tensorboard: False +log_memory_to_tensorboard: False +log_world_size_to_tensorboard: False +log_loss_scale_to_tensorboard: True +wandb_project: '' +wandb_exp_name: '' +wandb_save_dir: '' +enable_one_logger: True +one_logger_project: megatron-lm +one_logger_run_name: null +log_interval: 100 +tensorboard_dir: null diff --git a/examples/gpt3/train_gpt3_175b_distributed.sh b/examples/gpt3/train_gpt3_175b_distributed.sh index 01ca2e0309..7d2c01b315 100755 --- a/examples/gpt3/train_gpt3_175b_distributed.sh +++ b/examples/gpt3/train_gpt3_175b_distributed.sh @@ -12,11 +12,11 @@ NUM_NODES=1 NODE_RANK=0 WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES)) -CHECKPOINT_PATH=$0 # -TENSORBOARD_LOGS_PATH=$1 # -VOCAB_FILE=$2 #/gpt2-vocab.json -MERGE_FILE=$3 #/gpt2-merges.txt -DATA_PATH=$4 #_text_document +CHECKPOINT_PATH=$1 # +TENSORBOARD_LOGS_PATH=$2 # +VOCAB_FILE=$3 #/gpt2-vocab.json +MERGE_FILE=$4 #/gpt2-merges.txt +DATA_PATH=$5 #_text_document DISTRIBUTED_ARGS=( --nproc_per_node $GPUS_PER_NODE @@ -31,6 +31,7 @@ GPT_MODEL_ARGS=( --num-attention-heads 96 --seq-length 2048 --max-position-embeddings 2048 + --attention-backend auto # Can use (flash/fused/unfused/local) ) TRAINING_ARGS=( @@ -49,7 +50,6 @@ TRAINING_ARGS=( --min-lr 6.0e-6 --lr-warmup-fraction .001 --lr-decay-iters 430000 - --use-mcore-models ) MODEL_PARALLEL_ARGS=( diff --git a/examples/inference/README.md b/examples/inference/README.md new file mode 100644 index 0000000000..b4b07cbc6a --- /dev/null +++ b/examples/inference/README.md @@ -0,0 +1,279 @@ +### Megatron Core Inference Documentation +This guide provides an example for Megatron Core for running model inference. + +### Contents +- [Megatron Core Inference Documentation](#megatron-core-inference-documentation) +- [Contents](#contents) + - [1. Quick Start](#1-quick-start) + - [1.1 Understanding The Code](#11-understanding-the-code) + - [1.2 Running The Code](#12-running-the-code) + - [2. Flow of Control In MCore Backend](#2-flow-of-control-in-mcore-backend) + - [3. Customizing The Inference Pipeline](#3-customizing-the-inference-pipeline) + - [3.1. 
Create Your Own Inference Backend](#31-create-your-own-inference-backend) + - [3.2. Implement a new Sampling Loop](#32-implement-a-new-sampling-loop) + - [3.3. Support Other Models](#33-support-other-models) + - [3.4. Modify Inference Parameters](#34-modify-inference-parameters) + - [4. Future work](#4-future-work) + +
+ +#### 1. Quick Start +This example runs batch inference on a GPT model trained using Megatron Core. The entrypoint is [gpt_batch_inference.py](./gpt/gpt_batch_inference.py). + +
+ +##### 1.1 Understanding The Code +***STEP 1 - Initialize model parallel and other default arguments*** +The micro batch size is set to 1: it is not used when the model is tensor-parallel only, and for pipeline-parallel models it is calculated at runtime. +```python + initialize_megatron( + args_defaults={'no_load_rng': True, 'no_load_optim': True, 'micro_batch_size': 1} + ) +``` + +***STEP 2 - Load the model using the model_provider_function*** +NOTE: The model provider function supports both MCore and Legacy models. + +```python + model = get_model(model_provider, wrap_with_ddp=False) + load_checkpoint(model, None, None) + model = model[0] +``` + +***STEP 3 - Choose an engine*** +Text generation requires an inference engine, which includes a scheduler. The default engine is the [Megatron Core engine](../../megatron/core/inference/engines/mcore_engine.py) with a simple [text generation controller](../../megatron/core/inference/text_generation_controllers/text_generation_controller.py). TRTLLMEngine will be supported in the future. +```python + inference_wrapped_model = GPTInferenceWrapper(model, args) + text_generation_controller = TextGenerationController( + inference_wrapped_model=inference_wrapped_model, + tokenizer=tokenizer + ) + inference_engine = MCoreEngine( + text_generation_controller=text_generation_controller, max_batch_size=args.max_batch_size + ) +``` + +***STEP 4 - Run text generation*** +The [SamplingParams](../../megatron/core/inference/sampling_params.py) contains suggested defaults. Customize this to change top_p, top_k, the number of tokens to generate, etc. (a construction sketch follows this section). +*Note: The result is returned as a list of [InferenceRequests](../../megatron/core/inference/inference_request.py)* +```python + results: List[InferenceRequest] = inference_engine.generate( + prompts=args.prompts, sampling_params=sampling_params + ) + + if torch.distributed.get_rank() == 0: + for idx, result in enumerate(results): + print(f' ------------- RESULT FOR PROMPT {idx} --------------- ') + result = { + 'id': result.request_id, + 'input_prompt': result.prompt, + 'generated_text': result.generated_text, + 'generated_tokens' : result.generated_tokens + } + print(result) +``` + +
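The walkthrough above passes a `sampling_params` object into `generate()` without showing how it is built. Below is a minimal construction sketch; it only uses fields that appear elsewhere in this example (temperature, top_k, top_p, return_log_probs, num_tokens_to_generate), and the values shown are illustrative rather than recommended settings.

```python
from megatron.core.inference.sampling_params import SamplingParams

# top_k=1 effectively makes decoding greedy; raise top_k/top_p for more diverse samples.
sampling_params = SamplingParams(
    temperature=1.0,
    top_k=1,
    top_p=0.0,
    return_log_probs=False,
    num_tokens_to_generate=30,
)
```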
+ +##### 1.2 Running The Code +An example run script is shown below. Set the tokenizer paths, inference params, and other settings appropriately. + +For a quick recap on sampling parameters, refer to [this blog](https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910). + +``` +# In a slurm cluster (You could also use docker) +ACCOUNT= +MLM_PATH=/path/to/megatron-lm +GPT_CKPT=/path/to/gpt/ckpt +VOCAB_MERGE_FILE_PATH=/path/to/vocab/and/merge/file +CONTAINER_IMAGE=nvcr.io/ea-bignlp/ga-participants/nemofw-training:23.11 + +srun --account $ACCOUNT \ +--job-name=$ACCOUNT:inference \ +--partition=batch \ +--time=01:00:00 \ +--container-image $CONTAINER_IMAGE \ +--container-mounts $MLM_PATH:/workspace/megatron-lm/,$GPT_CKPT:/workspace/mcore_gpt_ckpt,$VOCAB_MERGE_FILE_PATH:/workspace/tokenizer \ +--no-container-mount-home \ +--pty /bin/bash \ + +# Inside the container run the following. + +cd megatron-lm/ +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +TOKENIZER_ARGS=( + --vocab-file /workspace/tokenizer/gpt2-vocab.json + --merge-file /workspace/tokenizer/gpt2-merges.txt + --tokenizer-type GPT2BPETokenizer +) + +MODEL_ARGS=( + --use-checkpoint-args + --use-mcore-models + --load /workspace/mcore_gpt_ckpt +) + +INFERENCE_SPECIFIC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --num-tokens-to-generate 20 + --max-batch-size 4 +) + +torchrun --nproc-per-node=4 examples/inference/gpt/simple_gpt_batch_inference.py \ + ${TOKENIZER_ARGS[@]} \ + ${MODEL_ARGS[@]} \ + ${INFERENCE_SPECIFIC_ARGS[@]} \ + --prompts "prompt one " "sample prompt two" "sample prompt 3" + +NOTE: Other parameters which can be customized for inference are :- +--temperature (Sampling temperature) +--top_k (top_k sampling) +--top_p (top_p sampling) +--num-tokens-to-generate (Number of tokens to generate for each prompt) +--inference-batch-times-seqlen-threshold (During inference, if batch-size times sequence-length is smaller than this threshold then we will not use pipelining, otherwise we will.') +--use-dist-ckpt (If using dist checkpoint format for the model) +--use-legacy-models (If using legacy gpt model instead of mcore gpt model) + +``` + + +
+ + +#### 2. Control Flow in the MCore Backend +An example of inference with static batching is provided in [gpt_batch_inference.py](./gpt/gpt_batch_inference.py). A simplified sketch of this flow is shown after the list. +* The [mcore_engine](../../megatron/core/inference/engines/mcore_engine.py) **generate()** function is called with the input prompts. +* The `Scheduler` in the engine adds these prompts to the [active requests pool](../../megatron/core/inference/inference_request.py) until the max batch size is hit. Remaining requests are added to the waiting requests pool. +* The engine runs until all requests (waiting + active) are completed. + * The active requests are passed into **generate_all_output_tokens_static_batch()** of the text generation controller. + * This function uses the **prep_model_for_inference()** method of the [model_inference_wrappers](../../megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py) and runs an autoregressive sampling loop. + * In the autoregressive loop, the **get_batch_for_context_window()** method of the inference wrapper is called to slice out the input tokens and masks. + * Input tokens and masks are passed into the **run_one_forward_step()** method, which calls the model `.forward()` method to get the output logits. + * Output logits are synchronized across all pipeline parallel ranks. + * The text generation controller obtains the log probabilities and samples tokens based on the strategy defined in the sampling parameters. + * The sampled tokens are then appended to the input prompt tokens for the next iteration. + * The **update_generation_status()** method of the text generation controller checks which prompts have finished generating or hit a stop condition. + * After the inference loop, the result is detokenized and stored as an attribute of the InferenceRequest. These requests are marked as completed. + * The **update_requests_pool()** method of the scheduler moves completed requests into the completed request pool and waiting requests into the active request pool. + +
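The following sketch condenses the list above into pseudocode. The names mirror the prose description (scheduler pools, `generate_all_output_tokens_static_batch()`, `update_requests_pool()`) rather than the exact signatures in the implementation, so treat it purely as an orientation aid.

```python
# Illustrative pseudocode of the static-batch control flow described above.
# Method and attribute names follow the prose; the real engine differs in detail.
def run_static_batch(engine, prompts, sampling_params):
    for prompt in prompts:
        # Requests beyond the max batch size land in the waiting pool.
        engine.scheduler.add_request(prompt, sampling_params)

    while engine.scheduler.have_requests_pending():
        active = engine.scheduler.active_request_pool
        # Autoregressive loop: slice context, run a forward step, sample, update status.
        finished = engine.text_generation_controller.generate_all_output_tokens_static_batch(active)
        # Completed requests move out; waiting requests become active.
        engine.scheduler.update_requests_pool(finished)

    return list(engine.scheduler.completed_request_pool.values())
```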
+ +#### 3. Customizing The Inference Pipeline + +The inference pipeline supports the following levels of customization: + +* **Inference engine** - The MCore Engine is currently supported. Change this to add a new backend. +* **Text generation controller** - The main sampling loop. This can be customized to support alternative tokenization, detokenization, or to implement a new sampling strategy. +* **Inference Wrapped Model** - Change this to support a new model. +* **Modify Inference Parameters** - Change this to update top_p, top_k, number of tokens to be generated, temperature, or other sampling parameters. + +
+ +##### 3.1. Create Your Own Inference Backend +The [abstract_engine.py](./../../megatron/core/inference/engines/abstract_engine.py) file defines the `generate` method that every backend must implement; extend this class to support a new backend. + +```python +class AbstractEngine(ABC): + @staticmethod + def generate(self) -> dict: + """The abstract backend's generate function. + + To define a new backend, implement this method and return the outputs as a dictionary. + """ +``` + +
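As a concrete illustration, a toy backend could subclass the abstract engine as sketched below. This is only a shape example under the assumption that a dictionary of per-prompt outputs is an acceptable return value; it does not schedule requests or run a model.

```python
from megatron.core.inference.engines.abstract_engine import AbstractEngine


class EchoEngine(AbstractEngine):
    """Toy backend sketch: 'generates' output by echoing each prompt back."""

    def generate(self, prompts) -> dict:
        # A real backend would schedule the requests and run model forward steps here.
        return {idx: {"prompt": p, "generated_text": p} for idx, p in enumerate(prompts)}
```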
+ +##### 3.2. Implement a new Sampling Loop + +The [TextGenerationController](../../megatron/core/inference/text_generation_controllers/text_generation_controller.py) contains the main sampling loop and can be modified to support new tokenization, detokenization, or sampling strategies. + +``` python +class TextGenerationController: + + def tokenize_prompt(self, prompt: str) -> Tuple[torch.Tensor, torch.Tensor]: + """Utility to tokenize the input prompts""" + + def sample_from_logits( + self, + last_token_logits: torch.Tensor, + sampling_params: SamplingParams, + vocab_size: int, + ) -> torch.Tensor: + """Samples the logits to generate outputs + + Given the logits of the last token, this function samples according to the parameters defined in sampling_params and returns the sampled tokens. + """ + + def update_generation_status( + self, + updated_prompts_tokens: torch.Tensor, + generation_started: torch.Tensor, + current_context_end_position: int, + is_generation_done_tensor: torch.Tensor, + generated_sequence_lengths: torch.Tensor, + ) -> torch.Tensor: + """Function to check which prompts have reached an end condition + + We check which prompts have reached an end condition and set the corresponding flags of the is_generation_done_tensor to True . The generated sequence lengths increases as we keep generating, until that prompts hits an eod condition. The generation started status tensor helps us determine which prompts have started generating + """ + + def generate_all_output_tokens_static_batch( + self, active_requests: OrderedDict[int, InferenceRequest], + ) -> OrderedDict[int, InferenceRequest]: + """Utility to generate all the output tokens and probabilities for the prompts . + + This utility generates the output tokens for a static batch. It runs the forward steps till all prompts complete generation, updates the status of these requests to completed, adds the generated result and returns these requests + """ + + def detokenize_generations(self, prompt_tokens_with_generated_tokens: torch.Tensor) -> str: + """Detokenize the output generations""" +``` + +
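For example, a controller that always decodes greedily could inherit everything except `sample_from_logits`. The sketch below assumes the signature shown above; `**kwargs` is included defensively in case the implementation passes extra arguments.

```python
import torch

from megatron.core.inference.text_generation_controllers.text_generation_controller import (
    TextGenerationController,
)


class GreedyTextGenerationController(TextGenerationController):
    """Sketch of a controller that ignores the sampling settings and always takes the argmax token."""

    def sample_from_logits(self, last_token_logits, sampling_params, vocab_size, **kwargs):
        # Greedy decoding: pick the highest-scoring token (within the real vocab) for each prompt.
        return torch.argmax(last_token_logits[:, :vocab_size], dim=-1)
```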
+ +##### 3.3. Support Other Models +Extend [abstract_model_inference_wrapper.py](./../../megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py) to support other models. The abstract model wrapper implements: +* Forward method which calls the model `forward` method depending on model parallel settings +* Initializes the model and puts it in `.eval()` mode +* Setup for the input parameters (max batch size, max seq length) + +The following methods should be implemented: +```python +class AbstractModelInferenceWrapper: + def prep_model_for_inference(self, prompts_tokens: torch.Tensor): + """A utility function for preparing model for inference + + The function gets called once before the auto regressive inference loop. It puts the model in eval mode , and gets some model and inference data parameters. Extend this to build position ids ,attention mask etc, so that required slices can be extracted during the forward pass + """ + + @abc.abstractclassmethod + def get_batch_for_context_window(self) -> List: + """Returns the input data for inference + + This function gets called iteratively in the inference loop. It can be used to extract relevant input from the prompt tokens, attention mask etc. required for each step in inference. +``` + +Refer to [gpt_inference_wrapper.py](../../megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py) for an example of implementing this for GPTModel. + +
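Once a wrapper for a new model exists, it plugs into the same controller and engine used in the quick-start example. The helper below shows that wiring; `inference_wrapped_model` is assumed to be an instance of your custom wrapper, and the import paths follow the GPT example in this document.

```python
from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
    TextGenerationController,
)


def build_engine_for_custom_model(inference_wrapped_model, tokenizer, max_batch_size=8):
    """Wire a custom inference wrapper into the standard controller and engine."""
    controller = TextGenerationController(
        inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer
    )
    return MCoreEngine(text_generation_controller=controller, max_batch_size=max_batch_size)
```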
+ +##### 3.4. Modify Inference Parameters +We use [SamplingParams](../../megatron/core/inference/sampling_params.py) for text generation. Customize this if you want to change top_p, top_k, the number of tokens to generate, etc. If you want to add other attributes for use in the inference loop, you can do that as shown below: + +``` +from megatron.core.inference.sampling_params import SamplingParams + +c = SamplingParams(temperature=0.5) +c.add_attributes({'min_length':4, 'eod_id':153}) +``` + +
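The customized object is then passed to the engine exactly like the default one; `inference_engine` below refers to the engine built in the quick-start section.

```python
# Reuse the engine from the quick-start example with the customized parameters.
results = inference_engine.generate(prompts=["sample prompt"], sampling_params=c)
```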
+ +#### 4. Future work +The following features are planned for the future releases. +* Dynamic batching +* Paged Attention +* TRTLLM Engine support +* Support for multimodal inference \ No newline at end of file diff --git a/examples/inference/gpt/gpt_batch_inference.py b/examples/inference/gpt/gpt_batch_inference.py new file mode 100644 index 0000000000..9c2bada4b6 --- /dev/null +++ b/examples/inference/gpt/gpt_batch_inference.py @@ -0,0 +1,200 @@ +import os +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, +) +from pretrain_gpt import model_provider +import torch +import sys +import time +import tqdm +import warnings +from argparse import Namespace +from megatron.core.inference.engines.abstract_engine import AbstractEngine +from megatron.core.inference.engines.mcore_engine import MCoreEngine +from megatron.core.inference.sampling_params import SamplingParams +from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( + GPTInferenceWrapper, +) +from megatron.core.inference.inference_request import InferenceRequest +from megatron.core.inference.text_generation_controllers.text_generation_controller import ( + TextGenerationController, +) +from megatron.core.transformer.module import MegatronModule + +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) +) + +from megatron.training import get_args +from megatron.training import get_tokenizer +from megatron.training.checkpointing import load_checkpoint +from megatron.core import mpu +from megatron.training.initialize import initialize_megatron +from megatron.training import get_model +import asyncio +from typing import AsyncIterator, List + + + +def add_text_generate_args(parser): + """Text generation arguments.""" + group = parser.add_argument_group(title='text generation') + + group.add_argument("--temperature", type=float, default=1.0, help='Sampling temperature.') + group.add_argument("--top_k", type=int, default=1, help='Top k sampling.') + group.add_argument("--top_p", type=float, default=0.0, help='Top p sampling.') + group.add_argument( + "--return-log-probs", + action='store_true', + default=False, + help='Return the log probabilities of the final output tokens', + ) + group.add_argument( + "--num-tokens-to-generate", + type=int, + default=30, + help='Number of tokens to generate for each prompt', + ) + group.add_argument( + "--prompts", + metavar='N', + type=str, + nargs='+', + help='Input prompts with each prompt within quotes and seperated by space', + ) + group.add_argument( + "--max-batch-size", type=int, default=8, dest="inference_max_requests", + help='Max number of prompts to process at once' + ) + group.add_argument("--stream", action="store_true", default=False, help="Stream output tokens") + return parser + + +def get_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngine: + """Utility to get the relevant backend for running inference + + This function will automatically chose the TRTLLMBackend when possible, and if not revert to Mcore backend if the user does not specify any backends. TRT LLM Backend is not implmented yet. + + Args: + args (Namespace): The user arguments parsed from command line + model (MegatronModule): The megatron model . 
+ + Returns: + AbstractBackend: The chosen backend + """ + tokenizer = get_tokenizer() + + inference_wrapper_config = InferenceWrapperConfig( + hidden_size=args.hidden_size, + inference_batch_times_seqlen_threshold=args.inference_batch_times_seqlen_threshold, + fp32_residual_connection=args.fp32_residual_connection, + params_dtype=args.params_dtype, + padded_vocab_size=args.padded_vocab_size, + inference_max_requests=args.inference_max_requests, + inference_max_seq_length=args.inference_max_seq_length, + ) + + inference_wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config) + text_generation_controller = TextGenerationController(inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer) + return MCoreEngine(text_generation_controller=text_generation_controller) + + +async def generate( + inference_engine: MCoreEngine, + sampling_params: SamplingParams, + prompts: List[str], +) -> List[InferenceRequest]: + async def collect_stream(prompt, request_id, stream_generator): + print(f"Request {request_id}: {prompt}", end="", flush=True) + prev_idx = 0 + async for output in stream_generator: + print(output.generated_text[prev_idx:], end="", flush=True) + prev_idx = len(output.generated_text) + print() + + request_ids: List[str] = [ + inference_engine.add_request( + prompt=prompt, inference_parameters=sampling_params, streaming=True + ) + for prompt in prompts + ] + stream_generators = [inference_engine.get_stream_generator(request_id) for request_id in request_ids] + + tasks = [ + asyncio.create_task(collect_stream(prompt, request_id, stream_generator)) + for (prompt, request_id, stream_generator) in zip(prompts, request_ids, stream_generators) + ] + + await inference_engine.run_engine_async() + await asyncio.gather(*tasks) + + results: List[InferenceRequest] = [ + inference_engine.scheduler.completed_request_pool[request_id] for request_id in request_ids + ] + + return results + +def main(): + """Main program.""" + + # Note: The default args passed here can be overwritten by using appropriate params (check arguments.py file) + # Micro batch size is not needed to be set by user. 
(It is calculated based on inference-batch-times-seqlen-threshold argument) + initialize_megatron( + extra_args_provider=add_text_generate_args, + args_defaults={ + 'no_load_rng': True, + 'no_load_optim': True, + 'micro_batch_size': 1, + 'exit_on_missing_checkpoint': True, + }, + ) + + # Set up model and load checkpoint + model = get_model(model_provider, wrap_with_ddp=False) + load_checkpoint(model, None, None) + model = model[0] + + args = get_args() + + inference_engine = get_inference_engine(args, model) + + sampling_params = SamplingParams( + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + return_log_probs=args.return_log_probs, + num_tokens_to_generate=args.num_tokens_to_generate, + ) + + if args.enable_cuda_graph: + print(f"Running warmup for CUDA graphs...") + inference_engine.generate( + prompts=args.prompts, sampling_params=sampling_params + ) + + start_time = time.perf_counter() + if args.stream: + results: List[InferenceRequest] = asyncio.run(generate(inference_engine, sampling_params, args.prompts)) + else: + results: List[InferenceRequest] = inference_engine.generate( + prompts=args.prompts, sampling_params=sampling_params, + ) + end_time = time.perf_counter() + latency = end_time - start_time + + if torch.distributed.get_rank() == 0: + for idx, result in enumerate(results): + print(f' \n------------- RESULT FOR PROMPT {idx} --------------- ') + result = { + 'id': result.request_id, + 'input_prompt': result.prompt, + 'generated_text': result.generated_text, + 'generated_tokens': result.generated_tokens, + 'latency': latency, + } + print(result) + + torch.distributed.destroy_process_group() + +if __name__ == "__main__": + main() diff --git a/examples/inference/llama_mistral/huggingface_reference.py b/examples/inference/llama_mistral/huggingface_reference.py new file mode 100644 index 0000000000..9d8f4465f6 --- /dev/null +++ b/examples/inference/llama_mistral/huggingface_reference.py @@ -0,0 +1,25 @@ +import argparse +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +# Set up argument parsing +parser = argparse.ArgumentParser(description="Script for text generation with a specific model and prompt.") +parser.add_argument('--prompt', type=str, required=True, help="Prompt text to use for text generation") +parser.add_argument('--model-path', type=str, required=True, help="Path to the Huggingface model checkpoint") + +# Parse command-line arguments +args = parser.parse_args() + +model_path = args.model_path +prompt = args.prompt + +config = AutoConfig.from_pretrained(model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path, config=config) +model = AutoModelForCausalLM.from_pretrained(model_path, config=config).cuda() + +inputs = tokenizer(prompt, return_tensors="pt") +for key in inputs: + inputs[key] = inputs[key].cuda() +# top_k, top_p and do_sample are set for greedy argmax based sampling + +outputs = model.generate(**inputs, max_length=100, do_sample=False, top_p=0, top_k=0, temperature=1.0) +print(tokenizer.decode(outputs[0], skip_special_tokens=True)) \ No newline at end of file diff --git a/examples/inference/llama_mistral/run_text_generation_llama3.1.sh b/examples/inference/llama_mistral/run_text_generation_llama3.1.sh new file mode 100755 index 0000000000..06584f0917 --- /dev/null +++ b/examples/inference/llama_mistral/run_text_generation_llama3.1.sh @@ -0,0 +1,56 @@ +#!/bin/bash +# This example will start serving the Llama3.1-8B model +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export 
NVTE_APPLY_QK_LAYER_SCALING=0 + +DISTRIBUTED_ARGS="--nproc_per_node 1 \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr 0.0.0.0 \ + --master_port 6000" + +# Ensure CHECKPOINT and TOKENIZER_MODEL are provided +if [ -z "$1" ] || [ -z "$2" ]; then + echo "Error: You must provide CHECKPOINT and TOKENIZER_MODEL as command-line arguments." + echo "Usage: $0 /path/to/checkpoint /path/to/tokenizer_model" + exit 1 +fi + +# Assign command-line arguments to variables +CHECKPOINT=$1 +TOKENIZER_MODEL=$2 + +pip install flask-restful + +torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \ + --use-checkpoint-args \ + --disable-bias-linear \ + --tokenizer-type HuggingFaceTokenizer \ + --tokenizer-model ${TOKENIZER_MODEL} \ + --transformer-impl transformer_engine \ + --normalization RMSNorm \ + --group-query-attention \ + --num-query-groups 8 \ + --no-masked-softmax-fusion \ + --attention-softmax-in-fp32 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --untie-embeddings-and-output-weights \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 500000 \ + --use-rope-scaling \ + --use-rotary-position-embeddings \ + --swiglu \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --num-layers 32 \ + --hidden-size 4096 \ + --ffn-hidden-size 14336 \ + --load ${CHECKPOINT} \ + --num-attention-heads 32 \ + --max-position-embeddings 131072 \ + --bf16 \ + --micro-batch-size 1 \ + --seq-length 8192 diff --git a/examples/inference/llama_mistral/run_text_generation_llama3.sh b/examples/inference/llama_mistral/run_text_generation_llama3.sh new file mode 100755 index 0000000000..c5fc4103ab --- /dev/null +++ b/examples/inference/llama_mistral/run_text_generation_llama3.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# This example will start serving the Llama3-8B model +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NVTE_APPLY_QK_LAYER_SCALING=0 + +DISTRIBUTED_ARGS="--nproc_per_node 1 \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr 0.0.0.0 \ + --master_port 6000" + +# Ensure CHECKPOINT and TOKENIZER_MODEL are provided +if [ -z "$1" ] || [ -z "$2" ]; then + echo "Error: You must provide CHECKPOINT and TOKENIZER_MODEL as command-line arguments." 
+ echo "Usage: $0 /path/to/checkpoint /path/to/tokenizer_model" + exit 1 +fi + +# Assign command-line arguments to variables +CHECKPOINT=$1 +TOKENIZER_MODEL=$2 + +pip install flask-restful + +torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \ + --use-checkpoint-args \ + --disable-bias-linear \ + --tokenizer-type HuggingFaceTokenizer \ + --tokenizer-model ${TOKENIZER_MODEL} \ + --transformer-impl transformer_engine \ + --normalization RMSNorm \ + --group-query-attention \ + --num-query-groups 8 \ + --no-masked-softmax-fusion \ + --attention-softmax-in-fp32 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --untie-embeddings-and-output-weights \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 500000 \ + --use-rotary-position-embeddings \ + --swiglu \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --num-layers 32 \ + --hidden-size 4096 \ + --ffn-hidden-size 14336 \ + --load ${CHECKPOINT} \ + --num-attention-heads 32 \ + --max-position-embeddings 8192 \ + --bf16 \ + --micro-batch-size 1 \ + --seq-length 8192 diff --git a/examples/inference/llama_mistral/run_text_generation_mistral.sh b/examples/inference/llama_mistral/run_text_generation_mistral.sh new file mode 100755 index 0000000000..4358fd494c --- /dev/null +++ b/examples/inference/llama_mistral/run_text_generation_mistral.sh @@ -0,0 +1,53 @@ +#!/bin/bash +# This example will start serving the Mistral-7B-v0.3 model +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +DISTRIBUTED_ARGS="--nproc_per_node 1 \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr 0.0.0.0 \ + --master_port 6000" + +# Ensure CHECKPOINT and TOKENIZER_MODEL are provided +if [ -z "$1" ] || [ -z "$2" ]; then + echo "Error: You must provide CHECKPOINT and TOKENIZER_MODEL as command-line arguments." 
+ echo "Usage: $0 /path/to/checkpoint /path/to/tokenizer_model" + exit 1 +fi + +# Assign command-line arguments to variables +CHECKPOINT=$1 +TOKENIZER_MODEL=$2 + +pip install flask-restful + +torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \ + --tokenizer-type HuggingFaceTokenizer \ + --tokenizer-model ${TOKENIZER_MODEL} \ + --use-checkpoint-args \ + --apply-layernorm-1p \ + --transformer-impl transformer_engine \ + --normalization RMSNorm \ + --group-query-attention \ + --num-query-groups 8 \ + --no-masked-softmax-fusion \ + --use-flash-attn \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --ffn-hidden-size 14336 \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --num-layers 32 \ + --hidden-size 4096 \ + --load ${CHECKPOINT} \ + --num-attention-heads 32 \ + --max-position-embeddings 4096 \ + --bf16 \ + --micro-batch-size 1 \ + --seq-length 4096 \ + --seed 101 diff --git a/examples/run_text_generation_server_345M.sh b/examples/inference/run_text_generation_server_345M.sh similarity index 92% rename from examples/run_text_generation_server_345M.sh rename to examples/inference/run_text_generation_server_345M.sh index a151b98467..e8e61adb16 100755 --- a/examples/run_text_generation_server_345M.sh +++ b/examples/inference/run_text_generation_server_345M.sh @@ -26,9 +26,6 @@ torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \ --fp16 \ --micro-batch-size 1 \ --seq-length 1024 \ - --out-seq-length 1024 \ - --temperature 1.0 \ --vocab-file $VOCAB_FILE \ --merge-file $MERGE_FILE \ - --top_p 0.9 \ --seed 42 diff --git a/examples/run_text_generation_server_345M_8_tensor_parallel.sh b/examples/inference/run_text_generation_server_345M_8_tensor_parallel.sh similarity index 92% rename from examples/run_text_generation_server_345M_8_tensor_parallel.sh rename to examples/inference/run_text_generation_server_345M_8_tensor_parallel.sh index 027ab42172..368cec3b31 100755 --- a/examples/run_text_generation_server_345M_8_tensor_parallel.sh +++ b/examples/inference/run_text_generation_server_345M_8_tensor_parallel.sh @@ -24,9 +24,6 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS tools/run_text_generation_s --fp16 \ --micro-batch-size 1 \ --seq-length 1024 \ - --out-seq-length 1024 \ - --temperature 1.0 \ --vocab-file $VOCAB_FILE \ --merge-file $MERGE_FILE \ - --top_p 0.9 \ --seed 42 diff --git a/examples/inference/t5/simple_t5_batch_inference.py b/examples/inference/t5/simple_t5_batch_inference.py new file mode 100644 index 0000000000..b4226d7de0 --- /dev/null +++ b/examples/inference/t5/simple_t5_batch_inference.py @@ -0,0 +1,157 @@ +import os +import sys +from argparse import Namespace + +import torch + +import pretrain_t5 +from megatron.core.inference.sampling_params import SamplingParams +from megatron.core.inference.engines.abstract_engine import AbstractEngine +from megatron.core.inference.engines.mcore_engine import MCoreEngine +from megatron.core.inference.inference_request import InferenceRequest +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, +) +from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import ( + T5InferenceWrapper, +) +from megatron.core.inference.text_generation_controllers.encoder_decoder_text_generation_controller import ( + EncoderDecoderTextGenerationController, +) +from megatron.core.transformer.module import 
MegatronModule +from pretrain_t5 import model_provider + +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) +) + +from typing import List + +from megatron.core import mpu +from megatron.training import get_args, get_model, get_tokenizer +from megatron.training.checkpointing import load_checkpoint +from megatron.training.initialize import initialize_megatron + + +def add_text_generate_args(parser): + """Text generation arguments.""" + group = parser.add_argument_group(title='text generation') + + group.add_argument("--temperature", type=float, default=1.0, help='Sampling temperature.') + group.add_argument("--top_k", type=int, default=1, help='Top k sampling.') + group.add_argument("--top_p", type=float, default=0.0, help='Top p sampling.') + group.add_argument( + "--return-log-probs", + action='store_true', + default=False, + help='Return the log probabilities of the final output tokens', + ) + group.add_argument( + "--num-tokens-to-generate", + type=int, + default=30, + help='Number of tokens to generate for each prompt', + ) + group.add_argument( + "--encoder-prompts", + metavar='N', + type=str, + nargs='+', + help='Encoder input prompts with each prompt within quotes and seperated by space', + ) + group.add_argument( + "--max-batch-size", type=int, default=1, help='Max number of prompts to process at once' + ) + return parser + + +def get_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngine: + """Utility to get the relevant backend for running inference + + This function will automatically chose the TRTLLMBackend when possible, and if not revert to Mcore backend if the user does not specify any backends. TRT LLM Backend is not implmented yet. + + Args: + args (Namespace): The user arguments parsed from command line + model (MegatronModule): The megatron model . + + Returns: + AbstractBackend: The chosen backend + """ + tokenizer = get_tokenizer() + + inference_wrapper_config = InferenceWrapperConfig( + hidden_size=args.hidden_size, + inference_batch_times_seqlen_threshold=args.inference_batch_times_seqlen_threshold, + fp32_residual_connection=args.fp32_residual_connection, + params_dtype=args.params_dtype, + padded_vocab_size=args.padded_vocab_size, + ) + + inference_wrapped_model = T5InferenceWrapper(model, inference_wrapper_config) + text_generation_controller = EncoderDecoderTextGenerationController( + inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer + ) + return MCoreEngine( + text_generation_controller=text_generation_controller, max_batch_size=args.max_batch_size + ) + + +def main(): + """Main program.""" + + # Note: The default args passed here can be overwritten by using appropriate params (check arguments.py file) + # Micro batch size is not needed to be set by user. 
(It is calculated based on inference-batch-times-seqlen-threshold argument) + initialize_megatron( + extra_args_provider=add_text_generate_args, + args_defaults={ + 'no_load_rng': True, + 'no_load_optim': True, + 'micro_batch_size': 1, + 'exit_on_missing_checkpoint': True, + }, + ) + + # Set up model and load checkpoint + model = get_model(model_provider, wrap_with_ddp=False) + load_checkpoint(model, None, None) + model = model[0] + + args = get_args() + + inference_engine = get_inference_engine(args, model) + + sampling_params = SamplingParams( + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + return_log_probs=args.return_log_probs, + num_tokens_to_generate=args.num_tokens_to_generate, + ) + + tokenizer = get_tokenizer() + decoder_prompts = [""] * len( + args.encoder_prompts + ) # for T5, the prompt is provided as encoder input, hence decoder_prompts is empty + args.prompts = decoder_prompts + + results: List[InferenceRequest] = inference_engine.generate( + prompts=args.prompts, + add_BOS=True, + encoder_prompts=args.encoder_prompts, + sampling_params=sampling_params, + ) + + if torch.distributed.get_rank() == 0: + for idx, result in enumerate(results): + print(f' \n------------- RESULT FOR PROMPT {idx} --------------- ') + result = { + 'id': result.request_id, + 'input_prompt': result.prompt, + 'generated_text': result.generated_text, + 'generated_tokens': result.generated_tokens, + } + print(result) + + +if __name__ == "__main__": + main() diff --git a/examples/mamba/.gitignore b/examples/mamba/.gitignore new file mode 100644 index 0000000000..940f4797e4 --- /dev/null +++ b/examples/mamba/.gitignore @@ -0,0 +1,4 @@ +checkpoints/ +data-cache/ +tensorboard/ +triton-cache/ diff --git a/examples/mamba/Dockerfile b/examples/mamba/Dockerfile new file mode 100644 index 0000000000..2e194095b7 --- /dev/null +++ b/examples/mamba/Dockerfile @@ -0,0 +1,32 @@ +FROM nvcr.io/nvidia/pytorch:24.01-py3 + +RUN pip uninstall -y triton && \ + pip install triton==2.1.0 sentencepiece==0.1.99 flask-restful + +# The causal-conv1d and mamba-ssm packages below are built from scratch here +# (which takes significant time) because there are no wheels available on PyPI +# for these relatively newer versions of the packages that are compatible with +# the older NGC-variant PyTorch version (e.g. version 2.2.0.dev231106) that we +# are using (in the NGC base container). Generally, if the package is not +# compatible with the PyTorch version, then it will generate a Python import +# error. The package authors tend to only release wheels for new versions of +# these pacakges which are compatible with the versions of regular PyTorch and +# NGC-variant PyTorch that are newer at the time of release. So, to use newer +# versions of these packages with relatively older versions of the NGC PyTorch +# container, we tend to have to build the packages from scratch. + +RUN cd /tmp && \ + git clone https://github.com/Dao-AILab/causal-conv1d.git && \ + cd causal-conv1d && \ + git checkout v1.2.2.post1 && \ + CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install . && \ + cd .. && \ + rm -rf causal-conv1d + +RUN cd /tmp && \ + git clone https://github.com/state-spaces/mamba.git && \ + cd mamba && \ + git checkout v2.0.3 && \ + MAMBA_FORCE_BUILD=TRUE pip install . && \ + cd .. 
&& \ + rm -rf mamba diff --git a/examples/mamba/README.md b/examples/mamba/README.md new file mode 100644 index 0000000000..f8f6d79683 --- /dev/null +++ b/examples/mamba/README.md @@ -0,0 +1,94 @@ +# Mamba-based Language Models + +## Introduction + +This document is an entrypoint into the code used for +[An Empirical Study of Mamba-based Language Models](https://arxiv.org/abs/2406.07887). + +We are releasing the parameters for some of the models described in that +technical report via +[HuggingFace](https://huggingface.co/collections/nvidia/ssms-666a362c5c3bb7e4a6bcfb9c). +The code in the `main` branch is no longer compatible with the `Mamba2-*` +checkpoints. You can load them using the +[fixed snapshot of the code used for the technical report](https://github.com/NVIDIA/Megatron-LM/tree/ssm/examples/mamba). + +## Installation + +Create and run a Docker container using the [Dockerfile](./Dockerfile). + +``` +docker build -t your_image_name:your_tag . +docker run --gpus all -it --rm \ + -v /path/to/megatron:/workspace/megatron \ + -v /path/to/dataset:/workspace/dataset \ + -v /path/to/checkpoints:/workspace/checkpoints \ + -w /workspace/megatron/examples/mamba \ + your_image_name:your_tag +``` + +## Train + +[`train.sh`](./train.sh) is an example pretraining script, showing how to run on +a single node. Select between 800M-scale and 8B-scale models by setting the +`MODEL_SCALE` variable. The 8B-scale hybrid model architecture is the same as +the one described in the technical report. + +## Text Generation + +Use [`run_text_gen_server_8b.sh`](./run_text_gen_server_8b.sh) to start a text +generation server using an 8B hybrid checkpoint. This is configured to run the +8B hybrid model described in the technical report, with tensor model parallel +set to 1. + +The arguments in the script will need to be changed if using a checkpoint with a +different model parallel configuration or other differences, such as model +architecture. For example, to run the 8B pure Mamba-2 model, change +`--hybrid-attention-ratio` and `--hybrid-mlp-ratio` to 0.0, or remove them. + +Use [`run_text_gen_server_8b_gpt3.sh`](./run_text_gen_server_8b_gpt3.sh) to start +a text generation server using the 8B reference Transformer checkpoint. + +## Checkpoint Formats + +For inference, the model must be configured to match the checkpoint file used, +including the hybrid layer configuration and model parallel configuration. + +If you need to convert a hybrid checkpoint file to a different tensor parallel +or pipeline parallel size, use +[the hybrid conversion script](../../tools/checkpoint/hybrid_conversion.py). +There is an example run command at the end of that file. + +Before running that script, you will need to set `PYTHONPATH` to include the +root directory of your Megatron-LM repository clone. + +``` +export PYTHONPATH=:PYTHONPATH +``` + +## Hybrid Options + +`--hybrid-attention-ratio ATT` specifies a target ratio of attention layers +to total layers. For example, 4 attention layers out of 48 total layers is +specified by `--hybrid-attention-ratio 0.08`. + +`--hybrid-mlp-ratio MLP` specifies a target ratio of MLP layers to total +layers. For example, 24 MLP layers out of 48 total layers is specified by +`--hybrid-mlp-ratio 0.5`. + +* (`ATT` + `MLP`) must be less than or equal to 1.0. +* (1.0 - `ATT` - `MLP`) is the hybrid mamba ratio, the ratio of mamba layers to +total layers. +* `ATT` = `MLP` = 0 is a pure Mamba model. +* `ATT` = `MLP` = 0.5 is a transfomer model. 
+ +If either `ATT` or `MLP` is greater than 0.0 or if `--hybrid-override-pattern` +is specified, the logfile will include information about the hybrid layer +pattern used. `--hybrid-override-pattern` can be used to specify a different +pattern than the default, algorithmically-generated one. + +## Mamba vs Mamba-2 + +This codebase currently only supports Mamba-2, and not the original version of +Mamba. However, the +[fixed snapshot of the code used for the technical report](https://github.com/NVIDIA/Megatron-LM/tree/ssm/examples/mamba) +can be configured to run the original version of Mamba. diff --git a/examples/mamba/run_text_gen_server_8b.sh b/examples/mamba/run_text_gen_server_8b.sh new file mode 100755 index 0000000000..8d3137f244 --- /dev/null +++ b/examples/mamba/run_text_gen_server_8b.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +# Use: ./run_text_gen_server_8b.sh +# To launch the client: python ../../tools/text_generation_cli.py + +CHECKPOINT_PATH=$1 +TOKENIZER_PATH=$2 + +DISTRIBUTED_ARGS="--nproc_per_node 1 \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr localhost \ + --master_port 6000" + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_IB_TIMEOUT=19 +export NCCL_IB_QPS_PER_CONNECTION=4 + +export TRITON_CACHE_DIR="./triton-cache/" +export TRITON_CACHE_MANAGER="megatron.core.ssm.triton_cache_manager:ParallelFileCacheManager" + +torchrun $DISTRIBUTED_ARGS ../../tools/run_mamba_text_generation_server.py \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --untie-embeddings-and-output-weights \ + --num-layers 56 \ + --hidden-size 4096 \ + --load ${CHECKPOINT_PATH} \ + --num-attention-heads 32 \ + --group-query-attention \ + --num-query-groups 8 \ + --hybrid-attention-ratio 0.08 \ + --hybrid-mlp-ratio 0.5 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --disable-bias-linear \ + --normalization RMSNorm \ + --seq-length 4096 \ + --max-position-embeddings 4096 \ + --position-embedding-type none \ + --tokenizer-type GPTSentencePieceTokenizer \ + --tokenizer-model ${TOKENIZER_PATH} \ + --distributed-backend nccl \ + --distributed-timeout-minutes 1440 \ + --bf16 \ + --micro-batch-size 1 \ + --use-mcore-models \ + --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \ + --seed 42 diff --git a/examples/mamba/run_text_gen_server_8b_gpt3.sh b/examples/mamba/run_text_gen_server_8b_gpt3.sh new file mode 100644 index 0000000000..5413b245ed --- /dev/null +++ b/examples/mamba/run_text_gen_server_8b_gpt3.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# Use: ./run_text_gen_server_8b_gpt3.sh +# To launch the client: python ../../tools/text_generation_cli.py + +CHECKPOINT_PATH=$1 +TOKENIZER_PATH=$2 + +DISTRIBUTED_ARGS="--nproc_per_node 1 \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr localhost \ + --master_port 6000" + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_IB_TIMEOUT=19 +export NCCL_IB_QPS_PER_CONNECTION=4 + +torchrun $DISTRIBUTED_ARGS ../../tools/run_text_generation_server.py \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --use-flash-attn \ + --apply-layernorm-1p \ + --untie-embeddings-and-output-weights \ + --num-layers 32 \ + --hidden-size 4096 \ + --load ${CHECKPOINT_PATH} \ + --num-attention-heads 32 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --disable-bias-linear \ + --seq-length 4096 \ + --max-position-embeddings 4096 \ + --position-embedding-type rope \ + --rotary-percent 0.5 \ + --squared-relu \ + --tokenizer-type GPTSentencePieceTokenizer \ + --tokenizer-model 
${TOKENIZER_PATH} \ + --distributed-backend nccl \ + --distributed-timeout-minutes 1440 \ + --bf16 \ + --micro-batch-size 1 \ + --use-mcore-models \ + --transformer-impl local \ + --seed 42 diff --git a/examples/mamba/train.sh b/examples/mamba/train.sh new file mode 100755 index 0000000000..3952a997d4 --- /dev/null +++ b/examples/mamba/train.sh @@ -0,0 +1,105 @@ +#!/bin/bash + +# Use: ./train.sh + +MODEL_SCALE="800M" # or "8B" + +case "${MODEL_SCALE}" in + "800M") + TENSOR_MODEL_PARALLEL_SIZE=1 + NUM_LAYERS=48 + HIDDEN_SIZE=1024 + NUM_ATTENTION_HEADS=16 + GLOBAL_BATCH_SIZE=32 + ;; + "8B") + TENSOR_MODEL_PARALLEL_SIZE=4 + NUM_LAYERS=56 + HIDDEN_SIZE=4096 + NUM_ATTENTION_HEADS=32 + GLOBAL_BATCH_SIZE=8 + ;; + *) + echo "Invalid version specified" + exit 1 + ;; +esac + +DATA_PATH=$1 +TOKENIZER_PATH=$2 + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_IB_TIMEOUT=19 +export NCCL_IB_QPS_PER_CONNECTION=4 + +CHECKPOINT_DIR="./checkpoints" +DATACACHE_DIR="./data-cache" +TENSORBOARD_DIR="./tensorboard" + +mkdir -p ${CHECKPOINT_DIR} +mkdir -p ${DATACACHE_DIR} +mkdir -p ${TENSORBOARD_DIR} + +export TRITON_CACHE_DIR="./triton-cache/" +export TRITON_CACHE_MANAGER="megatron.core.ssm.triton_cache_manager:ParallelFileCacheManager" + +SEQ_LEN=4096 +TRAIN_SAMPLES=73242188 # 300B tokens / 4096 +LR_WARMUP_SAMPLES=50000 +LR_DECAY_SAMPLES=73192188 # TRAIN_SAMPLES - LR_WARMUP_SAMPLES + +options=" \ + --tensor-model-parallel-size ${TENSOR_MODEL_PARALLEL_SIZE} \ + --sequence-parallel \ + --pipeline-model-parallel-size 1 \ + --use-distributed-optimizer \ + --overlap-param-gather \ + --overlap-grad-reduce \ + --untie-embeddings-and-output-weights \ + --init-method-std 0.02 \ + --position-embedding-type none \ + --num-layers ${NUM_LAYERS} \ + --hidden-size ${HIDDEN_SIZE} \ + --num-attention-heads ${NUM_ATTENTION_HEADS} \ + --group-query-attention \ + --num-query-groups 8 \ + --hybrid-attention-ratio 0.08 \ + --hybrid-mlp-ratio 0.5 \ + --seq-length ${SEQ_LEN} \ + --max-position-embeddings ${SEQ_LEN} \ + --train-samples ${TRAIN_SAMPLES} \ + --lr-warmup-samples ${LR_WARMUP_SAMPLES} \ + --lr-decay-samples ${LR_DECAY_SAMPLES} \ + --save ${CHECKPOINT_DIR} \ + --load ${CHECKPOINT_DIR} \ + --data-path ${DATA_PATH} \ + --data-cache-path ${DATACACHE_DIR} \ + --split 99,1,0 \ + --tokenizer-type GPTSentencePieceTokenizer \ + --tokenizer-model ${TOKENIZER_PATH} \ + --distributed-backend nccl \ + --micro-batch-size 4 \ + --global-batch-size ${GLOBAL_BATCH_SIZE} \ + --lr 2.5e-4 \ + --min-lr 2.5e-5 \ + --lr-decay-style cosine \ + --weight-decay 0.1 \ + --clip-grad 1.0 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --disable-bias-linear \ + --normalization RMSNorm \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --log-interval 10 \ + --save-interval 2000 \ + --eval-interval 2000 \ + --eval-iters 32 \ + --bf16 \ + --use-mcore-models \ + --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \ + --no-create-attention-mask-in-dataloader \ + --tensorboard-dir ${TENSORBOARD_DIR}" + +torchrun --nproc_per_node 8 ../../pretrain_mamba.py ${options} diff --git a/examples/merge_mp_bert.sh b/examples/merge_mp_bert.sh deleted file mode 100755 index 1383433284..0000000000 --- a/examples/merge_mp_bert.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -TENSOR_MODEL_PARALLEL_SIZE=2 - -VOCAB_FILE=bert-vocab.txt -CHECKPOINT_PATH=checkpoints/bert_345m - -WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \ - --model-type BERT \ - --tensor-model-parallel-size $TENSOR_MODEL_PARALLEL_SIZE \ 
- --tokenizer-type BertWordPieceLowerCase \ - --vocab-file $VOCAB_FILE \ - --num-layers 24 \ - --hidden-size 1024 \ - --num-attention-heads 16 \ - --seq-length 512 \ - --max-position-embeddings 512 \ - --load $CHECKPOINT_PATH diff --git a/examples/mixtral/README.md b/examples/mixtral/README.md new file mode 100644 index 0000000000..e85eccd6ef --- /dev/null +++ b/examples/mixtral/README.md @@ -0,0 +1,132 @@ +# Mixtral 8x7B Model Inference and Finetuning + +## Download Mixtral 8x7B Checkpoints +Download Mixtral 8x7B HF format checkpoint from [HF-hub](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/) + +Or you can simply run this following script to download Mixtral 8x7B into a specific folder. +```python +from huggingface_hub import snapshot_download +SAVED_DIR = "" # Specify the saved directory +# Download HF checkpoints +snapshot_download(repo_id="mistralai/Mixtral-8x7B-v0.1", ignore_patterns=["*.pt"], local_dir=SAVED_DIR, local_dir_use_symlinks=False) +``` + +## Convert Mixtral 8x7B checkpoints from HF to MCore +The HF checkpoints can be converted to Megatron format by using the provided checkpoint converter for HF format. +The target model parallel size(e.g. TP,PP,EP) should be specified. + +Currently the converter doesn't support distributed checkpointing yet, so each different parallel config requires a specific checkpoint. +- For training, the recommended model parallel config is TP1EP8PP4 +- For inference, the recommended model parallel config is TP1EP1PP2 + +``` +TOKENIZER_MODEL=/workspace/checkpoints/mixtral-hf/tokenizer.model +MEGATRON_PATH="/workspace/megatron-lm" +export PYTHONPATH=$MEGATRON_PATH:$PYTHONPATH +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +TARGET_TP_SIZE="" +TARGET_EP_SIZE="" +TARGET_PP_SIZE="" + +HF_FORMAT_DIR=/workspace/checkpoints/mixtral-hf +MEGATRON_FORMAT_DIR=/workspace/checkpoints/mixtral-mcore-TP${TARGET_TP_SIZE}PP${TARGET_PP_SIZE}EP${TARGET_EP_SIZE} + +python tools/checkpoint/convert.py \ +--model-type GPT \ +--loader loader_mixtral_hf \ +--saver mcore \ +--target-tensor-parallel-size ${TARGET_TP_SIZE} \ +--target-pipeline-parallel-size ${TARGET_PP_SIZE} \ +--target-expert-parallel-size ${TARGET_EP_SIZE} \ +--load-dir ${HF_FORMAT_DIR} \ +--save-dir ${MEGATRON_FORMAT_DIR} \ +--tokenizer-model ${TOKENIZER_MODEL} +``` + +## Text generation with Mixtral 8x7B +Inference with Mixtral 8x7B requires at least 2 GPUS, such that a distributed checkpoint with EP>=2 or PP>=2 converted with above script is needed. + +The Megatron-LM have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`, launch it with the following script: +``` +#!/bin/bash +# This example will start serving the Mixtral 8x7B model. 
+DISTRIBUTED_ARGS="--nproc_per_node 2 \
+                  --nnodes 1 \
+                  --node_rank 0 \
+                  --master_addr localhost \
+                  --master_port 6000"
+
+CHECKPOINT=
+TOKENIZER_MODEL=
+
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+pip install flask-restful
+
+torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
+       --tensor-model-parallel-size 1 \
+       --pipeline-model-parallel-size 2 \
+       --expert-model-parallel-size 1 \
+       --load ${CHECKPOINT} \
+       --tokenizer-type Llama2Tokenizer \
+       --tokenizer-model $TOKENIZER_MODEL \
+       --use-mcore-models \
+       --max-position-embeddings 32768 \
+       --num-layers 32 \
+       --hidden-size 4096 \
+       --ffn-hidden-size 14336 \
+       --num-attention-heads 32 \
+       --normalization RMSNorm \
+       --disable-bias-linear \
+       --position-embedding-type rope \
+       --no-position-embedding \
+       --swiglu \
+       --untie-embeddings-and-output-weights \
+       --group-query-attention \
+       --num-query-groups 8 \
+       --bf16 \
+       --micro-batch-size 1 \
+       --seq-length 1024 \
+       --seed 42 \
+       --num-experts 8 \
+       --moe-router-topk 2 \
+       --moe-token-dispatcher-type alltoall \
+       --moe-grouped-gemm \
+       --mock-data \
+       --rotary-base 1000000
+```
+
+Once the server is running, you can query it with `tools/text_generation_cli.py`; it takes a single argument, the host the server is running on.
+
+```
+python tools/text_generation_cli.py localhost:5000
+```
+
+
+## Finetuning from pretrained Mixtral 8x7B
+To finetune a pretrained Mixtral 8x7B, use the following script:
+
+
+```bash
+PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.04-py3
+CHECKPOINT_PATH="" # Specify path to checkpoint dir
+TOKENIZER_MODEL="" # Specify path to tokenizer.model
+DATA_PATH="" # Specify path to data
+
+docker run \
+  --gpus=all \
+  --ipc=host \
+  --workdir /workspace/megatron-lm \
+  -v /path/to/data:/path/to/data \
+  -v /path/to/megatron-lm:/workspace/megatron-lm \
+  $PYTORCH_IMAGE \
+  bash examples/mixtral/train_mixtral_8x7b_distributed.sh $CHECKPOINT_PATH $TOKENIZER_MODEL $DATA_PATH
+```
+
+The above also applies to Mixtral 8x22B; set the model config (including hidden_size/head_num/num_layers/ffn_hidden_size) properly according to the original [config](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1/blob/main/config.json).
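+
+As a rough sketch, the 8x22B overrides to the `MODEL_ARGS` in the training script would look like the following; the values are taken from the upstream config and should be double-checked against `config.json` rather than copied blindly:
+
+```bash
+# Hypothetical Mixtral 8x22B values for the corresponding MODEL_ARGS entries;
+# verify each against https://huggingface.co/mistralai/Mixtral-8x22B-v0.1/blob/main/config.json
+#   --num-layers 56           # num_hidden_layers
+#   --hidden-size 6144        # hidden_size
+#   --ffn-hidden-size 16384   # intermediate_size
+#   --num-attention-heads 48  # num_attention_heads
+```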
+ +## Acknowledgements +Contributors outside NVIDIA for the huggingface converter and example of Mixtral models in Megatron-Core: +- Peng Li +- Jun Huang diff --git a/examples/mixtral/train_mixtral_8x7b_distributed.sh b/examples/mixtral/train_mixtral_8x7b_distributed.sh new file mode 100644 index 0000000000..ed44d60f5c --- /dev/null +++ b/examples/mixtral/train_mixtral_8x7b_distributed.sh @@ -0,0 +1,116 @@ +#!/bin/bash + +# Runs Mixtral 8x7B model + +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=${MASTER_ADDR:-"localhost"} +MASTER_PORT=${MASTER_PORT:-"6000"} +NNODES=${SLURM_NNODES:-"1"} +NODE_RANK=${RANK:-"0"} +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +CHECKPOINT_PATH=$1 +TOKENIZER_MODEL=$2 +DATA_PATH=$3 + +DISTRIBUTED_ARGS=( + --nproc_per_node $GPUS_PER_NODE + --nnodes $NNODES + --node_rank $NODE_RANK + --master_addr $MASTER_ADDR + --master_port $MASTER_PORT +) + +MODEL_ARGS=( + --use-mcore-models + --disable-bias-linear + --seq-length 4096 + --max-position-embeddings 32768 + --num-layers 32 + --hidden-size 4096 + --ffn-hidden-size 14336 + --num-attention-heads 32 + --init-method-std 0.01 + --attention-dropout 0.0 + --hidden-dropout 0.0 + --normalization RMSNorm + --position-embedding-type rope + --swiglu + --untie-embeddings-and-output-weights + --group-query-attention + --num-query-groups 8 + --no-masked-softmax-fusion + --no-position-embedding + --rotary-base 1000000 +) + +MOE_ARGS=( + --num-experts 8 + --moe-router-topk 2 + --moe-router-load-balancing-type aux_loss + --moe-aux-loss-coeff 1e-2 + --moe-grouped-gemm + --moe-token-dispatcher-type alltoall + --overlap-param-gather + --overlap-grad-reduce +) + +DATA_ARGS=( + --tokenizer-type Llama2Tokenizer + --tokenizer-model ${TOKENIZER_MODEL} + --data-path $DATA_PATH + --split 99990,8,2 +) + +TRAINING_ARGS=( + --micro-batch-size 1 + --global-batch-size 256 + --lr 1e-4 + --train-iters 500000 + --lr-decay-iters 320000 + --lr-decay-style cosine + --min-lr 1.0e-5 + --weight-decay 0.1 + --lr-warmup-iters 500 + --clip-grad 1.0 + --bf16 +) + +MODEL_PARALLEL_ARGS=( + --tensor-model-parallel-size 1 + --pipeline-model-parallel-size 4 + --expert-model-parallel-size 8 + --use-distributed-optimizer + --sequence-parallel +) + +LOGGING_ARGS=( + --log-interval 1 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --tensorboard-dir "${CHECKPOINT_PATH}/tensorboard" \ + --no-load-optim \ + --no-load-rng +) + +if [ -n "${WANDB_API_KEY}" ]; then + LOGGING_ARGS+=( + --wandb-project ${WANDB_PROJECT:-"Mixtral"} + --wandb-exp-name ${WANDB_NAME:-"Mixtral_8x7B"} + ) +fi + + +torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \ + ${MODEL_ARGS[@]} \ + ${MOE_ARGS[@]} \ + ${DATA_ARGS[@]} \ + ${TRAINING_ARGS[@]} \ + ${MODEL_PARALLEL_ARGS[@]} \ + ${LOGGING_ARGS[@]} diff --git a/examples/multimodal/Dockerfile b/examples/multimodal/Dockerfile new file mode 100644 index 0000000000..7b54091ae6 --- /dev/null +++ b/examples/multimodal/Dockerfile @@ -0,0 +1,26 @@ +FROM nvcr.io/nvidia/pytorch:24.02-py3 + +RUN apt update && \ + apt -y upgrade && \ + apt install -y --no-install-recommends \ + software-properties-common \ + build-essential \ + python3-pip \ + python3-dev \ + bash \ + git \ + vim \ + tmux \ + python-is-python3 \ + default-jre + +RUN pip install --upgrade pip +RUN pip install einops einops-exts sentencepiece braceexpand webdataset packaging +RUN pip install transformers datasets accelerate timm +RUN pip install pytest-cov pytest_mock nltk 
wrapt
+RUN pip install zarr "tensorstore==0.1.45"
+RUN pip install black isort click==8.0.2
+RUN pip install pycocoevalcap megatron-energon mistral-common tiktoken
+RUN pip install git+https://github.com/openai/CLIP.git
+# Use --no-deps for the following to avoid outdated and unnecessary dependencies.
+RUN pip install open_clip_torch open-flamingo[eval] --no-deps
diff --git a/examples/multimodal/README.md b/examples/multimodal/README.md
new file mode 100644
index 0000000000..a65839f8f1
--- /dev/null
+++ b/examples/multimodal/README.md
@@ -0,0 +1,157 @@
+# Multimodal Example
+
+*NOTE: This example is under active development and is expected to change.*
+
+The following walks through all the steps required to pretrain and instruction tune a llava architecture vision-language model (VLM). It is important to precisely follow all steps to obtain the benchmark scores at the end.
+
+This example has been tested on an A100-based DGX cluster. Pretraining and instruction tuning took approximately 1 day and 11 hours respectively on 64 GPUs using four-way tensor parallelism (tp=4). Training speed will scale approximately linearly with the number of GPUs available.
+
+Multimodal support in Megatron is still under active development. This example is not intended to produce state-of-the-art model quality (that would require more data and model refinements); it is merely intended to demonstrate the multimodal functionality in Megatron. If you hit any problems, please open a GitHub issue.
+
+## Setup
+
+### Docker container
+
+You can build a docker container using `examples/multimodal/Dockerfile` to run this example.
+
+### Language model
+
+Follow the instructions in [Mistral](../../docs/llama_mistral.md#mistral-7b) to download weights for Mistral-7B-Instruct-v0.3 from HuggingFace and convert them to mcore format with tensor parallel size 4.
+Please use the tokenizer from HuggingFace.
+
+### Vision model
+
+This example uses the OpenAI CLIP `ViT-L/14@336px` vision model. To download the weights from OpenAI and convert them to a format that can be loaded in Megatron, please run the following:
+
+```
+python examples/multimodal/model_converter/clip_converter.py --download-root /some/download/folder --output /some/output/folder --tensor-parallel-size 4 --use-te
+```
+
+### Combined model checkpoint
+
+Update the paths to point to the mcore converted CLIP and Mistral models and run the following script to combine the Mistral and CLIP models into a single multimodal checkpoint folder:
+
+```
+examples/multimodal/combine_lm_vision_checkpoints.sh /path/to/mistral/model /path/to/clip/model /output/dir
+```
+
+## Training
+
+### Pretraining
+
+1. Download the LLaVA-Pretrain dataset from Hugging Face and unzip the images folder (NOTE: 79GB of disk space required):
+
+   ```
+   git clone https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain
+   cd LLaVA-Pretrain
+   unzip images.zip
+   ```
+
+2. Run the following script to convert the data to webdataset format:
+
+   ```
+   cd
+   python examples/multimodal/convert_llava_pretrain_to_wds.py
+   ```
+
+3. Run the following command to convert to megatron-energon format:
+
+   ```
+   cd /wds
+   energon prepare ./
+   ```
+
+   Select the following values for the presented options:
+
+   ```
+   > Please enter a desired train/val/test split like "0.5, 0.2, 0.3" or "8,1,1": 9,1,0
+   > Do you want to create a dataset.yaml interactively? [Y/n]: Y
+   > Please enter a number to choose a class: 10 (VQAWebdataset)
+   > Do you want to set a simple field_map[Y] (or write your own sample_loader [n])?
[Y/n]: Y
+   > Please enter a webdataset field name for 'image' (): jpg
+   > Please enter a webdataset field name for 'context' (): json[0][value]
+   > Please enter a webdataset field name for 'answers' (typing.Optional[typing.List[str]], default: None): json[1][value]
+   > Please enter a webdataset field name for 'answer_weights' (typing.Optional[torch.Tensor], default: None):
+   ```
+
+4. Update `pretrain_dataset.yaml` so that both `path` variables point to `LLaVA-Pretrain/wds`
+
+5. Run the following script to pretrain a llava model for image captioning:
+
+   ```
+   cd
+   examples/multimodal/pretrain_mistral_clip.sh
+   ```
+
+All being well, you should observe training and validation loss curves similar to the following:
+
+Pretraining loss curves
+
+These curves were obtained with a global batch size of 256. Changing this value will likely change the curves. For pretraining and instruction tuning llava models, we have found that loss curves are an unreliable predictor of downstream task performance. Therefore, it is necessary to run test generation and evaluation on a range of metrics to understand model quality. We intend to add training-time zero-shot evaluation in a future update.
+
+You can execute the pretraining script multiple times to resume training. On resuming, the latest model, optimizer, and dataloader state are loaded.
+
+### SFT
+
+1. Prepare an instruction tuning dataset such as one in [megatron-energon format](https://nvidia.github.io/Megatron-Energon/data_prep.html#). NOTE: we do not provide instructions for this.
+
+2. Update `sft_dataset.yaml` so that both `path` variables point to the train and val splits of your instruction tuning dataset.
+
+Run the following script to instruction tune the pre-trained llava model:
+
+   ```
+   examples/multimodal/sft_mistral_clip.sh
+   ```
+
+You can execute the SFT script multiple times to resume training. On resuming, the latest model, optimizer, and dataloader state are loaded.
+
+## Evaluation
+
+### Generation
+
+Run the following script:
+
+```
+examples/multimodal/text_generation_mistral_clip.sh --input-image-path /path/to/input/images --output-path /some/output/directory \
+    --model-path /path/to/model.pt --gt-path /path/to/groundtruth/file --task generation-task-name
+```
+
+where `--task generation-task-name` is the name of the evaluation benchmark such as `captioning` or `MMMU`.
+
+### After pretraining
+
+#### COCO captioning
+
+1. Download the COCO 2014 test image set:
+
+   ```wget http://images.cocodataset.org/zips/test2014.zip```
+
+2. Download COCO test image annotations:
+
+   ```https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json```
+
+3. Run text generation using `--task captioning`.
+
+4. Run the following command:
+
+   ```
+   python examples/multimodal/evaluate_coco.py --input-path /output/directory/from/generation --groundtruth-path /path/to/groundtruth/file
+   ```
+
+For the mistral-7b-instruct plus clip llava model you should obtain a COCO CIDEr score of approximately 94.
+
+### After SFT
+
+#### MMMU
+
+The official MMMU repository is currently not pip-installable, so please clone their code into `examples/multimodal` by running `git clone https://github.com/MMMU-Benchmark/MMMU.git`.
+
+The MMMU dataset is loaded from HuggingFace automatically as part of the code.
+
+Run text generation using `--task MMMU`.
Then, run the following command: + +``` +python examples/multimodal/evaluate_mmmu.py --input-path /output/directory/from/generation +``` + +For the mistral-7b-instruct plus clip instruction tuned llava model you should obtain a MMMU score of approximately 38. diff --git a/examples/multimodal/assets/pretrain_curves.png b/examples/multimodal/assets/pretrain_curves.png new file mode 100644 index 0000000000..7981a73ba1 Binary files /dev/null and b/examples/multimodal/assets/pretrain_curves.png differ diff --git a/examples/multimodal/combine_lm_vision_checkpoints.sh b/examples/multimodal/combine_lm_vision_checkpoints.sh new file mode 100755 index 0000000000..52de16ecd2 --- /dev/null +++ b/examples/multimodal/combine_lm_vision_checkpoints.sh @@ -0,0 +1,57 @@ +#/bin/bash +MCORE_LM=$1 # +MCORE_VISION=$2 # +OUTPUT_DIR=$3 # +MODEL_TYPE=$4 # Model type. Default: Mistral CLIP example. + +if [[ $MODEL_TYPE == "nvlm" ]]; then + # NVLM TP=8 + python examples/multimodal/combine_state_dicts.py \ + --input \ + ${MCORE_LM}/iter_0000001/mp_rank_00/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_00/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_01/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_01/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_02/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_02/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_03/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_03/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_04/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_04/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_05/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_05/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_06/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_06/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_07/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_07/model_optim_rng.pt \ + --prefixes language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model \ + --output \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_00/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_01/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_02/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_03/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_04/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_05/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_06/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_07/model_optim_rng.pt +else + # Mistral CLIP example TP=4. 
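+    # Inputs are interleaved per tensor-parallel rank: for each mp_rank_0X, the language
+    # model and vision model checkpoints are merged into a single output file, with their
+    # parameter names prefixed "language_model." and "vision_model." respectively
+    # (see combine_state_dicts.py below).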
+ python examples/multimodal/combine_state_dicts.py \ + --input \ + ${MCORE_LM}/iter_0000001/mp_rank_00/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_00/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_01/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_01/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_02/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_02/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_03/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_03/model_optim_rng.pt \ + --prefixes language_model vision_model language_model vision_model language_model vision_model language_model vision_model \ + --output \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_00/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_01/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_02/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_03/model_optim_rng.pt +fi + +echo 1 > ${OUTPUT_DIR}/latest_checkpointed_iteration.txt diff --git a/examples/multimodal/combine_state_dicts.py b/examples/multimodal/combine_state_dicts.py new file mode 100644 index 0000000000..2f7028474c --- /dev/null +++ b/examples/multimodal/combine_state_dicts.py @@ -0,0 +1,81 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +import argparse +import os +import sys + +import torch + +# Add megatron to the path. +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) +) + + +def combine(input_files, module_prefixes, output_files): + num_inputs_per_output = int(len(input_files) / len(output_files)) + + for output_idx, output_file in enumerate(output_files): + combined_state_dict = None + + lb = output_idx * num_inputs_per_output + ub = (output_idx + 1) * num_inputs_per_output + current_input_files = input_files[lb:ub] + current_module_prefixes = module_prefixes[lb:ub] + + for i, (input_file, module_prefix) in enumerate( + zip(current_input_files, current_module_prefixes) + ): + # initialize the combined state dict using the first provided input file + current_state_dict = torch.load(input_file) + if i == 0: + combined_state_dict = current_state_dict.copy() + combined_state_dict["model"] = dict() + + # copy model state dict and prefix names with the given module keys. + for k, v in current_state_dict["model"].items(): + combined_state_dict["model"]["%s.%s" % (module_prefix, k)] = v + + output_dir = os.path.dirname(output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + torch.save(combined_state_dict, output_file) + print("saved:", output_file) + + print("done.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=""" + Combine multiple state dicts into a single state dict. + The combined state dict is first initialized by taking a copy of the first provided input state dict. + To avoid conflicts in model parameter names, a prefix must be provided for each input file. + Model parameter names will be renamed from to .. 
+ + + Example usage: + python combine_state_dicts.py --input language_model.pt vision_model.pt --prefixes language_model vision_model --output multimodal.pt + """, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--input", nargs="*", required=True, help="paths to input state dict files") + parser.add_argument( + "--prefixes", + nargs="*", + required=True, + help="prefixes to use with each input model's parameters", + ) + parser.add_argument( + "--output", nargs="*", required=True, help="path(s) to output state dict file" + ) + + args = parser.parse_args() + + assert len(args.input) > 1, "must provide more than 1 input model to combine" + assert len(args.input) == len(args.prefixes), "each input model must have a corresponding key" + assert ( + len(args.input) % len(args.output) == 0 + ), "each output file must use the same number of input files" + + combine(args.input, args.prefixes, args.output) diff --git a/examples/multimodal/config.py b/examples/multimodal/config.py new file mode 100644 index 0000000000..e0de36f7a2 --- /dev/null +++ b/examples/multimodal/config.py @@ -0,0 +1,280 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from dataclasses import dataclass + +import torch + +from megatron.training.activations import fast_gelu, quick_gelu, squared_relu + + +def get_language_model_config(config): + if config.language_model_type == "llama3_8b": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 14336 + elif config.language_model_type == "llama3.1_8b": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 14336 + elif config.language_model_type == "llama3.1_70B": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 28672 + elif config.language_model_type == "mistral_7b": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 14336 + elif config.language_model_type == "yi-34b": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + 
config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 20480 + elif config.language_model_type == "qwen2.5_7B": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.add_qkv_bias = True + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 18944 + elif config.language_model_type == "qwen2.0_72B": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.add_qkv_bias = True + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 29568 + elif config.language_model_type == "llama3.2_1b": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 8192 + elif config.language_model_type.startswith("huggingface"): + # Loaded from HuggingFace config file. 
+ pass + else: + raise ValueError(f"unknown language model type {config.language_model_type}") + + return config + + +def get_vision_model_config(config, apply_query_key_layer_scaling): + if config.vision_model_type == "clip": + config.num_layers = 24 + config.num_attention_heads = 16 + config.add_bias_linear = True + config.add_qkv_bias = True + config.hidden_size = 1024 + config.hidden_dropout = 0.0 + config.attention_dropout = 0.0 + config.ffn_hidden_size = 4096 + config.gated_linear_unit = False + config.activation_func = quick_gelu + config.kv_channels = 64 + config.num_query_groups = 16 + config.layernorm_zero_centered_gamma = False + config.apply_query_key_layer_scaling = apply_query_key_layer_scaling + config.bias_activation_fusion = False + config.bias_dropout_fusion = False + config.attention_softmax_in_fp32 = True + config.normalization = 'LayerNorm' + config.apply_rope_fusion = False + elif config.vision_model_type == "siglip": + config.num_layers = 27 + config.num_attention_heads = 16 + config.add_bias_linear = True + config.add_qkv_bias = True + config.hidden_size = 1152 + config.hidden_dropout = 0.0 + config.attention_dropout = 0.0 + config.ffn_hidden_size = 4304 + config.gated_linear_unit = False + config.activation_func = fast_gelu + config.kv_channels = 72 + config.num_query_groups = 16 + config.layernorm_zero_centered_gamma = False + config.apply_query_key_layer_scaling = apply_query_key_layer_scaling + config.bias_activation_fusion = False + config.bias_dropout_fusion = False + config.attention_softmax_in_fp32 = True + config.normalization = 'LayerNorm' + config.apply_rope_fusion = False + config.qk_layernorm = False + config.layernorm_epsilon = 1e-6 + elif config.vision_model_type == "internvit": + config.num_layers = 45 + config.num_attention_heads = ((24 // config.tensor_model_parallel_size) + 1) * config.tensor_model_parallel_size + config.num_query_groups = config.num_attention_heads + config.add_bias_linear = True + config.add_qkv_bias = False + config.hidden_size = 3200 + config.hidden_dropout = 0.0 + config.attention_dropout = 0.0 + config.ffn_hidden_size = 12800 + config.gated_linear_unit = False + config.activation_func = torch.nn.functional.gelu + config.layernorm_zero_centered_gamma = False + config.apply_query_key_layer_scaling = apply_query_key_layer_scaling + config.bias_activation_fusion = False + config.bias_dropout_fusion = False + config.attention_softmax_in_fp32 = True + config.normalization = 'RMSNorm' + config.layernorm_epsilon = 1e-6 + config.apply_rope_fusion = False + elif config.vision_model_type == "radio": + config.num_layers = 32 + config.num_attention_heads = 16 + config.add_bias_linear = True + config.add_qkv_bias = True + config.hidden_size = 1280 + config.ffn_hidden_size = 5120 + config.gated_linear_unit = False + config.activation_func = fast_gelu + config.kv_channels = 80 + config.num_query_groups = 16 + config.layernorm_zero_centered_gamma = False + config.apply_query_key_layer_scaling = apply_query_key_layer_scaling + config.bias_activation_fusion = False + config.bias_dropout_fusion = False + config.attention_softmax_in_fp32 = True + config.normalization = 'LayerNorm' + config.apply_rope_fusion = False + config.qk_layernorm = False + config.layernorm_epsilon = 1e-6 + elif config.vision_model_type.startswith("huggingface"): + # Loaded from HuggingFace config file. 
+ pass + else: + raise ValueError(f"unknown vision model type {config.vision_model_type}") + + return config + + +def get_vision_projection_config(config, hidden_size): + config.gated_linear_unit = False + config.bias_activation_fusion = False + config.add_bias_linear = False + config.hidden_size = hidden_size # Used as the vision projection output size, i.e., the input to the language model. + if config.language_model_type == "llama3_8b": + config.ffn_hidden_size = 14336 + config.activation_func = torch.nn.functional.gelu + elif config.language_model_type == "llama3.1_8b": + config.ffn_hidden_size = 4096 + config.activation_func = torch.nn.functional.gelu + config.layernorm_epsilon = 1e-5 + config.add_bias_linear = True + config.normalization = "LayerNorm" + elif config.language_model_type == "mistral_7b": + config.ffn_hidden_size = 14336 + config.activation_func = torch.nn.functional.gelu + config.normalization = None + elif config.language_model_type == "yi-34b": + config.ffn_hidden_size = 20480 + config.normalization = "LayerNorm" + config.activation_func = torch.nn.functional.gelu + elif config.language_model_type == "qwen2.5_7B": + config.ffn_hidden_size = 3584 + config.activation_func = torch.nn.functional.gelu + elif config.language_model_type == "qwen2.0_72B": + config.ffn_hidden_size = 29568 + config.normalization = "LayerNorm" + config.activation_func = torch.nn.functional.gelu + elif config.language_model_type == "llama3.2_1b": + config.ffn_hidden_size = 2048 + config.activation_func = torch.nn.functional.gelu + config.normalization = "LayerNorm" + elif config.language_model_type.startswith("huggingface"): + config.activation_func = torch.nn.functional.gelu + from transformers import AutoConfig + hf_config = AutoConfig.from_pretrained(config.huggingface_model_name_or_path) + if "qwen" in hf_config.model_type: + config.ffn_hidden_size = 1536 + else: + raise ValueError(f"unknown language model type {config.language_model_type}") + + return config + + +@dataclass +class EvaluationConfig: + """Evaluation related configuration.""" + task: str + + temperature: float = 1.0 + top_p: float = 0.0 + top_k: int = 0 + + out_seq_length: int = 32 + + output_path: str = "" + + input_image_path: str = "" + gt_path: str = "" + + num_partitions: int = 1 + partition_id: int = 0 + num_samples_per_partition: int = 0 diff --git a/examples/multimodal/convert_llava_pretrain_to_wds.py b/examples/multimodal/convert_llava_pretrain_to_wds.py new file mode 100644 index 0000000000..0092aef246 --- /dev/null +++ b/examples/multimodal/convert_llava_pretrain_to_wds.py @@ -0,0 +1,31 @@ +import json +import os +import webdataset as wds + +from tqdm import tqdm + +llava_pretrain_dir = '' + +# Paths to the dataset files +json_file = os.path.join(llava_pretrain_dir, 'blip_laion_cc_sbu_558k.json') +output = os.path.join(llava_pretrain_dir, 'wds') + +if not os.path.exists(output): + os.mkdir(output) + +# Load data +with open(json_file, 'r') as f: + data = json.load(f) + +with wds.ShardWriter(os.path.join(output, 'pretrain-%d.tar'), maxcount=10000) as shard_writer: + for entry in tqdm(data): + with open(os.path.join(llava_pretrain_dir, entry['image']), "rb") as img_file: + image_data = img_file.read() + sample = { + "__key__": entry['id'], + "jpg": image_data, + "json": json.dumps(entry['conversations']).encode("utf-8"), + } + shard_writer.write(sample) + +print(f"Dataset successfully converted to wds") diff --git a/examples/multimodal/dataloader_provider.py b/examples/multimodal/dataloader_provider.py new file mode 
100644 index 0000000000..aef2186834 --- /dev/null +++ b/examples/multimodal/dataloader_provider.py @@ -0,0 +1,169 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import os + +import torch +from dataset_helpers import TaskEncoder, print_error_handler + +from megatron.core import parallel_state +from megatron.energon import ( + LimitDataset, + RepeatDataset, + WorkerConfig, + get_loader, + get_savable_loader, + get_train_dataset, + get_val_datasets, +) +from megatron.core.num_microbatches_calculator import get_num_microbatches +from megatron.core.parallel_state import get_tensor_model_parallel_rank, get_pipeline_model_parallel_world_size, get_pipeline_model_parallel_rank +from megatron.training import get_args +from megatron.training.checkpointing import get_checkpoint_name + + +def datasets_provider(worker_config=None): + """Create multimodal train, validation and test datasets.""" + args = get_args() + + dname = args.data_path[0] if type(args.data_path) is list else args.data_path + train_dataset = get_train_dataset( + dname, + batch_size=args.micro_batch_size, + task_encoder=TaskEncoder(), + worker_config=worker_config, + max_samples_per_sequence=None, + shuffle_buffer_size=None, + packing_buffer_size=args.packing_buffer_size, + handler=print_error_handler, + image_decode="pil", + ) + + val_datasets = get_val_datasets( + dname, + batch_size=args.micro_batch_size, + # This is the total number over all workers + # limit=args.eval_iters * get_num_microbatches(), + task_encoder=TaskEncoder(), + worker_config=worker_config, + packing_buffer_size=args.packing_buffer_size, + handler=print_error_handler, + image_decode="pil", + ) + val_datasets_without_source_datasets = [ + # Limit the dataset to eval_iters * num_microbatches + LimitDataset( + # Repeat the inner dataset in case it's too short + RepeatDataset(val_ds, worker_config=worker_config), + length=args.eval_iters * get_num_microbatches(), + worker_config=worker_config, + reset_after_epoch=True, + ) + for val_ds, _src_ds in val_datasets + ] + + return train_dataset, val_datasets_without_source_datasets, None + + +def is_first_or_last_stage(pp_size, encoder_pipeline_model_parallel_size): + """Check if the current pipeline parallel stage is the first or last stage.""" + if pp_size == 1: # No pipeline parallelism. + return True + + is_valid_rank = False + pp_rank = get_pipeline_model_parallel_rank() + if encoder_pipeline_model_parallel_size == 0: + # No separate pipeline stage for the vision model. Run the dataloader on the first and last pipeline stage. + is_valid_rank = pp_rank in (0, pp_size-1) + elif encoder_pipeline_model_parallel_size == 1: + # Separate pipeline stage for the vision model. Run the dataloader on the first vision and LM stage and last LM stage. + is_valid_rank = pp_rank in (0, 1, pp_size-1) + else: + raise NotImplementedError("encoder-pipeline-model-parallel-size > 1 is not supported yet") + + return is_valid_rank + + +def is_dataloader_rank(encoder_pipeline_model_parallel_size): + """Check if we should have the dataloader on this tensor and pipeline parallel rank.""" + # Run dataloader only on the first tensor parallel rank (will be broadcasted to others). 
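+    # The remaining tensor-parallel ranks receive their batch via broadcast, so running
+    # dataloader workers there would only duplicate work.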
+ is_first_rank = get_tensor_model_parallel_rank() == 0 + + pp_size = get_pipeline_model_parallel_world_size() + is_first_rank = is_first_rank and is_first_or_last_stage(pp_size, encoder_pipeline_model_parallel_size) + + return is_first_rank + + +def train_valid_test_dataloaders_provider(train_val_test_num_samples): + """Build multimodal train, validation and test dataloaders.""" + args = get_args() + + # Dataloader is only on specific ranks. + if not is_dataloader_rank(args.encoder_pipeline_model_parallel_size): + return None, None, None + + worker_debug_path = None + worker_log_level = 0 + + rank = parallel_state.get_data_parallel_rank() + world_size = parallel_state.get_data_parallel_world_size() + data_parallel_group = parallel_state.get_data_parallel_group() + + worker_config = WorkerConfig( + rank=rank, + world_size=world_size, + num_workers=args.num_workers, + data_parallel_group=data_parallel_group, + worker_debug_path=worker_debug_path, + worker_log_level=worker_log_level, + ) + train_ds, valid_ds1, test_ds = datasets_provider(worker_config) + + train_dataloader = get_savable_loader(train_ds, worker_config=worker_config) + if args.load is not None: + if getattr(args, "dataloader_save", None): + dp_rank = parallel_state.get_data_parallel_rank() + data_save_name = get_checkpoint_name( + args.dataloader_save, + args.iteration, + pipeline_rank=0, # Only the first pipeline parallel rank stores the dataloader checkpoint. + basename=f"train_dataloader_dprank{dp_rank:03d}.pt", + ) + if os.path.exists(data_save_name): + try: + dataset_state_dict = torch.load(data_save_name, map_location="cpu") + train_dataloader.restore_state_rank(dataset_state_dict["dataloader_state_dict"]) + print(f"restored dataset state from {data_save_name}") + except Exception as e: + print("loading dataset state failed. Skipping. " + str(e)) + else: + print(f"dataset state {data_save_name} does not exist") + + valid_dataloader = [ + EnergonDataloader(get_loader(valid_ds, worker_config=worker_config)) + for valid_ds in valid_ds1 + ] + test_dataloader = None + + return EnergonDataloader(train_dataloader), valid_dataloader, EnergonDataloader(test_dataloader) + + +class EnergonDataloader: + """A wrapper to use Megatron Energon dataloader with the Megatron-LM training loop.""" + def __init__(self, dataloader): + self._dataloader = dataloader + self._iter = iter(cyclic_iter(dataloader)) + + def __next__(self): + return self._iter.__next__() + + def __iter__(self): + return self._iter.__iter__() + + def save_state(self): + return self._dataloader.save_state_rank() + + +def cyclic_iter(iter): + while True: + for x in iter: + yield x diff --git a/examples/multimodal/dataset_helpers.py b/examples/multimodal/dataset_helpers.py new file mode 100644 index 0000000000..39fa44b338 --- /dev/null +++ b/examples/multimodal/dataset_helpers.py @@ -0,0 +1,906 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
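+# Task encoders and sample dataclasses for the multimodal dataloader, including helpers
+# for greedy-knapsack sequence packing of image-text samples.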
+import bisect +import dataclasses +import json +import re +import sys +import traceback +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +from image_processing import find_closest_aspect_ratio, find_closest_area_weighted_aspect_ratio, get_visual_transform +from PIL import Image +from torchvision.transforms import ToPILImage +import numpy as np +import torch + +from energon_util import OfflineTargetAspectRatioSample, SampleListSample +from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, IMAGE_TOKEN, VIDEO_TOKEN +from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings +from megatron.energon import ( + Batch, + CaptioningSample, + DefaultTaskEncoder, + OCRSample, + Sample, + SimilarityInterleavedSample, + VQASample, + MultiChoiceVQASample +) +from megatron.energon.task_encoder.base import stateless +from megatron.training import get_args, get_tokenizer + + +@dataclass +class ImageTaskSample(Sample): + __key__: str + __restore_key__: Tuple[Union[str, int, tuple], ...] + __subflavor__: Dict + __subflavors__: Dict + # (c, h, w) + imgs: List[torch.Tensor] + num_tiles: List[int] + tokens: torch.Tensor + total_len: int # Total token count in the sample, including text and image tokens + labels: torch.Tensor = None + + +@dataclass +class ImageTaskSamplePacked(Sample): + """Dataclass to store a single packed sample (not a batch). + + P = Number of sub-samples in the packed sample + seq_len = Total sequence length + num_imgs = Number of images across all samples in the packed sample + """ + + __key__: str # Sample name + __restore_key__: Tuple[Union[str, int, tuple], ...] + __subflavor__: Dict # Sample metadata. Deprecated. + __subflavors__: Dict # Sample metadata. + tokens: torch.Tensor # Input tokens packed into a single tensor (seq_len,) + labels: torch.Tensor # Target tokens packed into a single tensor (seq_len,) + imgs: List[torch.Tensor] # Input images + num_tiles: List[int] # Number of tiles for each image of each sample (num_imgs) + max_length: int # Maximum length across sub-samples. + cu_lengths: List[int] # Cumulative length of each sub-sample in this packed sample incl. text and image tokens (P,) + + +# Typing for the resulting batch data after encode_batch() +@dataclass +class ImageTaskBatchPacked(Batch): + """Dataclass to store a batch of packed samples. + + N = Batch size + P = Number of samples in the packed sample + seq_len = Maximum sequence length + num_imgs = Number of images across all samples in the packed sample + """ + + __key__: List[str] # Sample names + __restore_key__: Tuple[Union[str, int, tuple], ...] + __subflavor__: Dict # Sample metadata. Deprecated. + __subflavors__: List[Dict] # Sample metadatas. + tokens: torch.Tensor # Input tokens packed and padded (N, seq_len) + labels: torch.Tensor # Target tokens packed and padded (N, seq_len) + imgs: torch.Tensor # All image tiles stacked into a single tensor (num_tiles, C, H, W) + num_tiles: List[List[int]] # Number of tiles per image (N, num_imgs) + max_lengths: List[int] # Maximum length across sub-samples (N,) + cu_lengths: List[List[int]] # Cumulative length of each sub-sample in each packed sample of the batch (N, P) + + +# Based on https://github.com/hiyouga/LLaMA-Factory/blob/641d0dab08d96a93c34657742213d8994d9ed476/src/llamafactory/data/processors/processor_utils.py#L19 +# Copyright (c) 2024 LLaMA-Factory. Apache license 2.0. 
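+# Illustrative example (not part of the original code): for item sizes [2, 3, 5, 7] and
+# max_capacity 8, greedy_knapsack below packs [7], then [5, 3], then [2]; at every step
+# search_for_fit picks, via binary search, the largest remaining item that still fits.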
+def search_for_fit(numbers: List[int], capacity: int) -> int: + """Finds the index of largest number that fits into the knapsack with the given capacity.""" + index = bisect.bisect(numbers, capacity) + return -1 if index == 0 else (index - 1) + + +# Based on https://github.com/hiyouga/LLaMA-Factory/blob/641d0dab08d96a93c34657742213d8994d9ed476/src/llamafactory/data/processors/processor_utils.py#L27 +# Copyright (c) 2024 LLaMA-Factory. Apache license 2.0. +def greedy_knapsack(item_sizes: List[int], samples: List, max_capacity: int) -> List: + """Greedy algorithm with binary search for the knapsack problem. + + Pack as many samples as possible given a maximum capacity and capacities of individual samples. + Used if sequence packing is enabled. + """ + assert len(item_sizes) == len(samples), "sample lengths and samples must have the same length." + + knapsacks = [] + + if len(item_sizes) == 0: + return knapsacks + + # Sort sample lengths and samples together. + sorted_item_sizes, sorted_samples = zip(*sorted(zip(item_sizes, samples), key=lambda x: x[0])) + sorted_item_sizes = list(sorted_item_sizes) + sorted_samples = list(sorted_samples) + + # Check if all samples fit in the knapsack capacity. + if sorted_item_sizes[-1] > max_capacity: + raise ValueError(f"knapsack: A sample is larger {sorted_item_sizes[-1]} than the max_sequence_length {max_capacity}.") + + while sorted_item_sizes: + current_knapsack = [] + remaining_capacity = max_capacity + + while True: + idx = search_for_fit(sorted_item_sizes, remaining_capacity) + if idx == -1: + break # Can't fit more samples. + + remaining_capacity -= sorted_item_sizes[idx] + + sorted_item_sizes.pop(idx) + sample = sorted_samples.pop(idx) + current_knapsack.append(sample) + + knapsacks.append(current_knapsack) + + return knapsacks + + +class TaskEncoder(DefaultTaskEncoder[OCRSample, OCRSample, ImageTaskBatchPacked, dict]): + """A simple task encoder for VLMs.""" + + def __init__( + self + ): + super().__init__() + + self.args = get_args() + + self.tokenizer = get_tokenizer() + with open(self.args.prompt_path, "r") as f: + self.manual_prompts = json.load(f) + self.dataloader_seq_length = self.args.dataloader_seq_length # Always return samples of this length. + self.packing_seq_length = self.args.packing_seq_length # Packing sequence length, if packing is enabled. + self.is_packing_enabled = self.args.packing_buffer_size is not None and self.args.packing_buffer_size > 0 + + if self.dataloader_seq_length and self.packing_seq_length: + assert self.dataloader_seq_length >= self.packing_seq_length, "dataloader sequence length must be greater than or equal to the packing sequence length" + + if self.is_packing_enabled: + assert self.packing_seq_length > 0, "packing sequence length must be set" + + self.num_image_embeddings_per_tile = get_num_image_embeddings( + self.args.img_h, + self.args.img_w, + self.args.patch_dim, + self.args.vision_model_type, + self.args.disable_vision_class_token, + 1, + self.args.pixel_shuffle, + self.args.use_tile_tags, + ) + + self.txt_to_token_dict = {} + + self.img_h, self.img_w = self.args.img_h, self.args.img_w + self.img_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + # This map is used to reduce the number of tiles used per image if the number of tokens is + # larger than the decoder_seq_length. 
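+        # For example, an image initially split into 12 tiles is retried with 8, then 6,
+        # 4, 2 and finally 1 tile until the sample fits (see the tiling loop in
+        # encode_llava_sft below).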
+ self.num_tiles_degradation_map = {12:8, 8:6, 6:4, 4:2, 2:1, 1:1} + + self.find_closest_aspect_ratio_fn = ( + find_closest_area_weighted_aspect_ratio if self.args.use_area_weighted_aspect_ratio + else find_closest_aspect_ratio) + + def _get_total_seq_length(self, input_ids, num_tiles): + """Calculate expected sequence length given text tokens length and number of tiles.""" + total_num_images = len(num_tiles) + total_num_tiles = sum(num_tiles) + total_len = len(input_ids) + total_num_tiles * self.num_image_embeddings_per_tile - total_num_images + return total_len + + def _truncate_for_packing(self, input_ids, target, num_tiles): + """Truncate tokens and labels if they exceed packing sequence length.""" + total_num_images = len(num_tiles) + total_num_tiles = sum(num_tiles) + total_img_embeddings_len = total_num_tiles * self.num_image_embeddings_per_tile + max_text_tokens = self.packing_seq_length - total_img_embeddings_len + total_num_images + + input_ids = input_ids[:max_text_tokens] + target = target[:max_text_tokens] + + # If truncate causes all labels to be ignored, then skip the sample + if (target == IGNORE_INDEX).all(): + raise ValueError(f"all targets will be ignored after truncation: {input_ids}") + + return input_ids, target + + @stateless(restore_seeds=True) + def encode_sample(self, sample: Union[CaptioningSample, OCRSample, VQASample, SimilarityInterleavedSample]): + if isinstance(sample, OCRSample): + if "pdfa" in sample.__key__: + yield self.combined_ocr_encoder(sample, task_type='encode_pdf') + elif "multi" in sample.__key__: + yield self.combined_ocr_encoder(sample, task_type='_encode_ocr') + else: + yield self.combined_ocr_encoder(sample, task_type='encode_ocr_ref') + elif isinstance(sample, CaptioningSample): + yield self.encode_captioning(sample) + elif isinstance(sample, VQASample): + is_llava_training = sample.__subflavors__["is_llava_training"] if "is_llava_training" in sample.__subflavors__ else False + + if "llava" in sample.__key__ or is_llava_training: + yield self.encode_llava_pretrain(sample) + else: + yield self.encode_any_single_turn_vqa(sample) + elif isinstance(sample, SimilarityInterleavedSample): + yield self.encode_llava_sft(sample) + elif isinstance(sample, MultiChoiceVQASample): + yield self.encode_any_single_turn_vqa(sample) + # Because the SampleListSample is defined in the Megatron module but loaded by the Energon + # library, we need to resort to the more brittle check: + elif type(sample).__name__ == "SampleListSample": + yield self.encode_sample_list(sample) + else: + raise NotImplementedError("Sample format not supported", sample) + + def encode_captioning(self, sample: CaptioningSample): + """Encode CaptioningSample.""" + augment = sample.__subflavors__.get("augmentation") + + imgs = get_visual_transform( + sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, self.args.use_thumbnail, augment, + self.args.vision_model_type, find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn + ) + num_tiles = [len(imgs)] + + prompt_list = self.manual_prompts["CaptioningPretraining"]["raw"] + + prompt_idx = np.random.randint(len(prompt_list)) + cur_prompt = prompt_list[prompt_idx] + cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt + "\n" + + caption = sample.caption.strip() + + split_by_line_flag = sample.__subflavors__.get("SplitByLine") + if split_by_line_flag: + caption_list = caption.split('\n') + caption = np.random.choice(caption_list) + + conv = [ + # Note: no system message. 
+ {"role": "user", "content": cur_prompt}, + {"role": "assistant", "content": caption}, + ] + + input_ids, target = self.tokenizer.tokenize_conversation(conv, True, False) + + if self.is_packing_enabled: + input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles) + + return ImageTaskSample( + __key__=sample.__key__, + __restore_key__=sample.__restore_key__, + __subflavor__=None, + __subflavors__=sample.__subflavors__, + imgs=imgs, + num_tiles=num_tiles, + tokens=torch.tensor(input_ids), + labels=torch.tensor(target), + total_len=self._get_total_seq_length(input_ids, num_tiles), + ) + + def encode_llava_pretrain(self, sample: VQASample): + """Encode pretrain sample in LLAVA style.""" + augment = sample.__subflavors__.get("augmentation", False) + + imgs = get_visual_transform( + sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, self.args.use_thumbnail, augment, + self.args.vision_model_type, find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn + ) + num_tiles = [len(imgs)] + + # LLAVA training: override text-prompt with just the image. + conv = [ + # Note: no system message. + {"role": "user", "content": IMAGE_TOKEN + "\n"}, + {"role": "assistant", "content": sample.answers}, + ] + + input_ids, target = self.tokenizer.tokenize_conversation(conv, True, False) + + if self.is_packing_enabled: + input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles) + + return ImageTaskSample( + __key__=sample.__key__, + __restore_key__=sample.__restore_key__, + __subflavor__=None, + __subflavors__=sample.__subflavors__, + imgs=imgs, + num_tiles=num_tiles, + tokens=torch.tensor(input_ids), + labels=torch.tensor(target), + total_len=self._get_total_seq_length(input_ids, num_tiles), + ) + + def encode_sample_list(self, samples: SampleListSample): + """We encode the list of samples using encode_llava_sft on each sample.""" + error_msg = ("You probably don't want to use online packing since SampleListSample is " + "usually used along offline packing.") + assert not self.is_packing_enabled, error_msg + encoded_samples = [] + current_length = 0 + for sample in samples.samples: + encoded_sample = self.encode_llava_sft(sample, truncate_for_sample_list_packing=True) + if current_length + encoded_sample.total_len > self.packing_seq_length: + break + else: + encoded_samples.append(encoded_sample) + current_length += encoded_sample.total_len + return self.pack_selected_samples(encoded_samples) + + def encode_llava_sft(self, sample: Union[SimilarityInterleavedSample, OfflineTargetAspectRatioSample], truncate_for_sample_list_packing=False): + """Encode SFT sample.""" + augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False + has_video = sample.__subflavors__['has_video'] if 'has_video' in sample.__subflavors__ else False + + # If the target aspect ratio are provided by the dataset, we use them instead of computing + # them with the self.find_closest_aspect_ratio_fn function. + local_find_closest_aspect_ratio_fn = self.find_closest_aspect_ratio_fn + if type(sample).__name__ == "OfflineTargetAspectRatioSample": + target_aspect_ratio = tuple(sample.target_aspect_ratio[0]) + assert target_aspect_ratio is not None, "Sample of type OfflineTargetAspectRatioSample needs to define the target aspect ratio." + local_find_closest_aspect_ratio_fn = lambda *args, **kwargs: target_aspect_ratio + + has_image = False + # We infer whether the sample has image or not. 
+ if hasattr(sample, "images") and not has_video: + # If this is a text-only sample and we are freezing the LM, + # then use a dummy input image. + if len(sample.images) == 0 and self.args.freeze_LM: + empty_img = Image.new('RGB', (self.args.img_w, self.args.img_h), (255, 255, 255)) + sample.images.append(empty_img) + if len(sample.images) > 0: + has_image = True + + # Note: Some tokenizers may ignore the system prompt. + conversation = [{"role": "system", "content": "Answer the questions."}] + # Format the conversation as a list of "user" / "assistant" turns. + for text in sample.texts: + error_msg = f"unexpected role {text['from']} in {sample.texts}" + assert text["from"] in ["human", "gpt"], error_msg + conversation.append({ + "role": "user" if text["from"] == "human" else "assistant", + "content": text["value"]}) + + # Replace the image tags with IMAGE_TOKEN and count the number of image tags + number_image_tags = 0 + image_tag_ids_list = [] + for turn in conversation: + if turn["role"] == "user": + image_tag_ids = [int(x) - 1 for x in re.findall(r"", turn["content"])] + image_tag_ids_list.extend(image_tag_ids) + turn["content"] = re.sub(r"", IMAGE_TOKEN, turn["content"]) + # For videos, we use the image token to locate where to put the frames. + if has_video: + turn["content"] = turn["content"].replace(VIDEO_TOKEN, IMAGE_TOKEN) + number_image_tags += turn["content"].count(IMAGE_TOKEN) + + # We re-order the images in sample.images according to how they appear in the conversation. + if len(image_tag_ids_list) > 0: + sample.images = [sample.images[idx] for idx in image_tag_ids_list] + + # If there is only one image, but several image tags, we assume all the tags refer to the + # same image and duplicate the image: + if not has_video and len(sample.images) == 1 and number_image_tags > 1: + sample.images = sample.images * number_image_tags + + # We currently only support one video per sample. + number_of_images = 1 if has_video else len(sample.images) + # Fail if there are more image or video tags than image or videos: + error_msg = ( + f"Found {number_image_tags} image tags for {number_of_images} images. {sample.texts}") + assert number_image_tags <= number_of_images, error_msg + + # If there are less image of video tags than image or videos, prepend the tags to the first + # user message: + if number_image_tags < number_of_images: + for turn in conversation: + if turn["role"] == "user": + turn["content"] = IMAGE_TOKEN*(number_of_images-number_image_tags) + "\n" + turn["content"] + break + + input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False) + + if has_image: + imgs = [] + num_tiles = [] + max_num_tiles = self.args.max_num_tiles + # We keep a buffer of 4 tokens for the question, + # the rest can be used for image tokens. + max_image_token_allowed = self.args.decoder_seq_length - len(input_ids) - 4 + # We start by extracting as many tiles per image as possible, and decrease the max + # number of tiles if there are too many image tokens. 
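+            # Illustrative example with hypothetical numbers: for decoder_seq_length=4096,
+            # 100 text tokens and 576 image embeddings per tile, the budget is
+            # 4096 - 100 - 4 = 3992 image tokens, so at most 6 tiles fit and max_num_tiles
+            # degrades 12 -> 8 -> 6 via self.num_tiles_degradation_map.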
+ while True: + imgs = [] + num_tiles = [] + for img in sample.images: + img_tiles = get_visual_transform( + img, self.img_h, self.img_w, self.args.use_tiling, max_num_tiles, + self.args.use_thumbnail, augment, self.args.vision_model_type, + find_closest_aspect_ratio_fn=local_find_closest_aspect_ratio_fn) + imgs += img_tiles + num_tiles += [len(img_tiles)] + if max_num_tiles == 1: + break + if sum(num_tiles) * self.num_image_embeddings_per_tile > max_image_token_allowed: + if max_num_tiles in self.num_tiles_degradation_map: + max_num_tiles = self.num_tiles_degradation_map[max_num_tiles] + else: + raise RuntimeError(( + f"Tried to decrease the number of tiles {max_num_tiles} but it's not ", + f"defined in the degradation map {self.num_tiles_degradation_map}")) + else: + break + elif has_video: + # We don't use tiling for videos to limit the number of tokens. + use_tiling=False + # Grab the selected frames of the video as a tensor with shape + # fhwc: (num_frames, num_channels, height, width). + video_fchw = sample.images.frames + if video_fchw.shape[0] == 0: + raise ValueError(f"Video {sample.__key__} {sample.__restore_key__} {sample.texts} has no frames.") + selected_frames = torch.linspace( + 0, video_fchw.shape[0] - 1, self.args.num_frames).long() + video_fchw = video_fchw[selected_frames] + imgs = [] + for video_chw in video_fchw: + to_pil = ToPILImage() + video_chw = to_pil(video_chw) + imgs += get_visual_transform( + video_chw, self.img_h, self.img_w, use_tiling, self.args.max_num_tiles, + self.args.use_thumbnail, augment, self.args.vision_model_type, + find_closest_aspect_ratio_fn=local_find_closest_aspect_ratio_fn) + num_tiles = [len(imgs)] + else: + imgs = num_tiles = [] + + if self.is_packing_enabled or truncate_for_sample_list_packing: + input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles) + + # Some final checks with respect to the number of image tokens and images on the tokenized + # conversation. There can still be errors, for instance if a non-video sample happens to + # have our pre-defined video token, or if the packing truncation removed a necessary image + # tag. + number_image_token = np.sum(input_ids == self.img_token_id) + error_msg = ( + f"Found {number_image_token} image tokens for len({num_tiles}) = {len(num_tiles)} image tiles in {conversation}.") + assert number_image_token == len(num_tiles), error_msg + error_msg = ( + f"Found sum({num_tiles}) = {np.sum(num_tiles)} tiles for {len(imgs)} images in {conversation}.") + assert np.sum(num_tiles) == len(imgs), error_msg + + # We need to ensure that there are at least some trainable tokens in the sample. + assert self.target_has_trainable_tokens(input_ids, num_tiles, target), "Sample has no trainable tokens." + + return ImageTaskSample( + __key__=sample.__key__, + __restore_key__=sample.__restore_key__, + __subflavor__=None, + __subflavors__=sample.__subflavors__, + imgs=imgs, + num_tiles=num_tiles, + tokens=torch.tensor(input_ids), + labels=torch.tensor(target), + total_len=self._get_total_seq_length(input_ids, num_tiles), + ) + + def target_has_trainable_tokens(self, input_ids, num_tiles, target): + # Compute the loss mask based on extending the image tags with the proper + # number of image tokens, extracting the first self.args.decoder_seq_length tokens, and + # ensuring that some of these tokens have a loss mask > 0. + # Note that this is a bit hacky because we reproduce here parts of the logics which are in + # the model itself. 
Ideally, the data sampler would return the already processed inputs + # and targets to avoid this duplication. + expanded_target = target.copy() + expanded_target[input_ids==self.img_token_id] = self.img_token_id + expanded_target = self.replace_value_with_repetition( + expanded_target, self.img_token_id, + self.num_image_embeddings_per_tile * np.array(num_tiles), IGNORE_INDEX) + loss_mask = torch.ones(torch.tensor(expanded_target).size(), dtype=torch.float) + loss_mask[expanded_target == self.tokenizer.pad] = 0.0 # mask paddings + loss_mask[expanded_target == IGNORE_INDEX] = 0.0 # mask prompts + loss_mask = torch.cat((loss_mask[1:], torch.zeros((1,)))) + loss_mask = loss_mask[:self.args.decoder_seq_length] + return torch.sum(loss_mask) > 0 + + def replace_value_with_repetition(self, arr, token_to_replace, num_repetition, new_token): + """ + Replace every occurrence of value V in the input array with R repetitions of W. + + Args: + arr (Array): Input array to be modified + token_to_replace: token to be replaced + new_token: new token + num_repetition (Array): number of repetition of new token. + + Returns: + Array: New array with token_to_replace replaced by num_repetition repetitions of + new_token + """ + error_msg = "The number of image tokens must match the length of the tile tensor." + assert np.sum(arr==token_to_replace) == len(num_repetition), error_msg + result = [] + idx = 0 + for item in arr: + if item == token_to_replace: + # If the current item matches token_to_replace, add R copies of W + result.extend([new_token] * num_repetition[idx]) + idx += 1 + else: + # Otherwise, keep the original item + result.append(item) + + return np.array(result) + + def encode_any_single_turn_vqa(self, sample): + """Encode MultiChoiceVQA or VQA sample.""" + augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False + has_video = sample.__subflavors__['has_video'] if 'has_video' in sample.__subflavors__ else False + + if has_video: + # Grab the selected frames of the video as a tensor with shape + # fhwc: (num_frames, height, width, num_channels). 
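
# The video branches in this encoder pick a fixed number of evenly spaced frames with
# torch.linspace, exactly as done just below. A small standalone sketch (the 32-frame
# dummy video and the target of 8 frames are made-up values):
import torch

video = torch.zeros(32, 3, 224, 224)  # any (num_frames, ...) tensor; only dim 0 matters here
num_frames = 8
selected = torch.linspace(0, video.shape[0] - 1, num_frames).long()
# selected == tensor([ 0,  4,  8, 13, 17, 22, 26, 31])
frames = video[selected]              # shape (8, 3, 224, 224)
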
+ video_fhwc = sample.image.permute(0, 2, 3, 1) + selected_frames = torch.linspace( + 0, video_fhwc.shape[0] - 1, self.args.num_frames).long() + video_frame_fhwc = video_fhwc[selected_frames] + imgs = [] + for video_frame_hwc in video_frame_fhwc: + imgs += get_visual_transform( + video_frame_hwc, self.img_h, self.img_w, + self.args.use_tiling, self.args.max_num_tiles, + self.args.use_thumbnail, augment, self.args.vision_model_type, + find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn) + else: + imgs = get_visual_transform( + sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, + self.args.use_thumbnail, augment, self.args.vision_model_type, + find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn + ) + + num_tiles = [len(imgs)] + + if isinstance(sample, MultiChoiceVQASample): + cur_prompt = format_multichoice_question(sample.context, sample.choices) + if IMAGE_TOKEN not in cur_prompt: + cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt + cur_answer = format_multichoice_answer(sample.correct_choice_idx) + elif isinstance(sample, VQASample): + if 'docvqa' in sample.__key__: + prompt_list = self.manual_prompts["VQASFT"]["docvqa"] + elif sample.__subflavors__.get("VQASFT"): + prompt_list = self.manual_prompts["VQASFT"]["raw"] + else: + prompt_list = ["{}"] + + prompt_idx = np.random.randint(len(prompt_list)) + cur_prompt = prompt_list[prompt_idx] + + cur_prompt = cur_prompt.format(sample.context) + + if IMAGE_TOKEN not in cur_prompt: + cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt + + if isinstance(sample.answers, list): + answer_list = sample.answers + weight_list = np.array(sample.answer_weights).astype(np.float32) + weight_list = weight_list / np.sum(weight_list) + answer_idx = np.random.choice(weight_list.shape[0], 1, p=weight_list)[0] + cur_answer = answer_list[answer_idx] + else: + cur_answer = sample.answers + else: + raise NotImplementedError("Unsupported data type provided", sample) + + conversation = [ + {"role": "system", "content": "Answer the questions."}, + {"role": "user", "content": cur_prompt}, + {"role": "assistant", "content": str(cur_answer)}, + ] + + input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False) + + if self.is_packing_enabled: + input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles) + + return ImageTaskSample( + __key__=sample.__key__, + __restore_key__=sample.__restore_key__, + __subflavor__=None, + __subflavors__=sample.__subflavors__, + imgs=imgs, + num_tiles=num_tiles, + tokens=torch.tensor(input_ids), + labels=torch.tensor(target), + total_len=self._get_total_seq_length(input_ids, num_tiles), + ) + + def combined_ocr_encoder(self, sample, task_type): + """Encode OCR samples.""" + augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False + + if task_type == "encode_pdf": + sample, cur_prompt, cur_answer = self.encode_pdf_prompt(sample) + elif task_type == "encode_ocr_ref": + sample, cur_prompt, cur_answer = self.encode_ocr_ref_prompt(sample) + elif task_type == "_encode_ocr": + sample, cur_prompt, cur_answer = self.encode_ocr_prompt(sample) + + imgs = get_visual_transform( + sample.image, self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, + self.args.use_thumbnail, augment, self.args.vision_model_type, + find_closest_aspect_ratio_fn=self.find_closest_aspect_ratio_fn + ) + num_tiles = [len(imgs)] + + conversation = [ + {"role": "system", "content": "Answer the questions."}, + {"role": "user", "content": 
cur_prompt}, + {"role": "assistant", "content": str(cur_answer)}, + ] + + input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False) + + if self.is_packing_enabled: + input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles) + + return ImageTaskSample( + __key__=sample.__key__, + __restore_key__=sample.__restore_key__, + __subflavor__=None, + __subflavors__=sample.__subflavors__, + imgs=imgs, + num_tiles=num_tiles, + tokens=torch.tensor(input_ids), + labels=torch.tensor(target), + total_len=self._get_total_seq_length(input_ids, num_tiles), + ) + + def encode_pdf_prompt(self, sample: OCRSample) -> ImageTaskSample: + """Encode OCR sample.""" + prompt_list = self.manual_prompts["DocPretraining"]["raw"] + prompt_idx = np.random.randint(len(prompt_list)) + cur_prompt = prompt_list[prompt_idx] + if IMAGE_TOKEN not in cur_prompt: + cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt + + # Make sure there is no extra IMAGE_TOKEN tag. + sample.text = sample.text.replace(IMAGE_TOKEN, "") + + caption = sample.text.strip() + + split_by_line_flag = sample.__subflavors__.get("SplitByLine") + if split_by_line_flag: + caption_list = caption.split('\n') + caption = np.random.choice(caption_list) + cur_answer = caption + + return sample, cur_prompt, cur_answer + + def encode_ocr_ref_prompt(self, sample: OCRSample) -> ImageTaskSample: + """Encode OCR sample.""" + ref = sample.text + region = sample.words_boxes + + # Make sure there is no extra IMAGE_TOKEN tag + ref = ref.replace(IMAGE_TOKEN, "") + + if len(region) == 4: + region = f"({region[0]},{region[1]}),({region[2]},{region[3]})" + else: + region = f"({region[0]},{region[1]}),({region[2]},{region[3]}),({region[4]},{region[5]}),({region[6]},{region[7]})" + + # Randomly choose between two tasks + task_idx = np.random.randint(2) + if task_idx == 0: + # Referring Grounding + prompt_list = self.manual_prompts["DocPretraining"]["referring_grounding"] + prompt_content = ref + answer = region + else: + # Grounded OCR + prompt_list = self.manual_prompts["DocPretraining"]["grounded_ocr"] + prompt_content = region + answer = ref + + prompt_idx = np.random.randint(len(prompt_list)) + cur_prompt = prompt_list[prompt_idx] + cur_prompt = cur_prompt.format(prompt_content) + if IMAGE_TOKEN not in cur_prompt: + cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt + + return sample, cur_prompt, answer + + def bbox_coord_to_label(self, text, bbox): + """Format bbox coordinates as text.""" + assert len(bbox) == 4 or len(bbox) == 8 + + # Make sure there is no extra IMAGE_TOKEN tag + text = text.replace(IMAGE_TOKEN, "") + + if len(bbox) == 4: + label_str = f"{text}({bbox[0]},{bbox[1]}),({bbox[2]},{bbox[3]})" + else: + label_str = f"{text}({bbox[0]},{bbox[1]}),({bbox[2]},{bbox[3]}),({bbox[4]},{bbox[5]}),({bbox[6]},{bbox[7]})" + + return label_str + + def encode_ocr_prompt(self, sample: OCRSample) -> ImageTaskSample: + """Encode OCR sample.""" + if isinstance(sample.words_boxes[0], int): + answer = self.bbox_coord_to_label(sample.text, sample.words_boxes) + elif isinstance(sample.words_boxes[0], list): + answer = "" + for i, bbox in enumerate(sample.words_boxes): + answer += self.bbox_coord_to_label(sample.words_text[i], bbox) + + prompt_list = self.manual_prompts["DocPretraining"]["ocr_multi"] + prompt_idx = np.random.randint(len(prompt_list)) + cur_prompt = prompt_list[prompt_idx] + + if IMAGE_TOKEN not in cur_prompt: + cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt + cur_answer = answer + + return sample, cur_prompt, cur_answer + + def batch(self, 
samples: List[Union[ImageTaskSample, ImageTaskSamplePacked]]) -> ImageTaskBatchPacked: + # Stack images to [num_tiles, c, h, w]. If there are no images (text-only), then use a dummy image. + imgs = [img for s in samples for img in s.imgs] + if len(imgs) > 0: + imgs = torch.stack(imgs) + else: + imgs = torch.tensor([[0]], dtype=torch.float32) + + # If the user hasn't defined a target dataloader sequence length, then use the max along the sample lengths. + max_seq_len = self.dataloader_seq_length + if not max_seq_len: + max_seq_len = max(len(s.tokens) for s in samples) + + tokens = np.full((len(samples), max_seq_len), self.tokenizer.pad, dtype=np.int64) + # +1 to accommodate shift to left by one later. + labels = np.full((len(samples), max_seq_len + 1), self.tokenizer.pad, dtype=np.int64) + + for i, s in enumerate(samples): + # If the sample/target length exceeds the target sequence length, then truncate. + text_len = min(max_seq_len, len(s.tokens)) + target_len = min(max_seq_len+1, len(s.labels)) + + tokens[i, :text_len] = s.tokens[:text_len] + labels[i, :target_len] = s.labels[:target_len] + + num_tiles = torch.tensor([n for s in samples for n in s.num_tiles], dtype=torch.int32) + if len(num_tiles) == 0: + num_tiles = torch.tensor([[0]], dtype=torch.int32) + + # Cumulative sample lengths are needed for packing, otherwise use dummy values. + cu_lengths = torch.tensor([[0]], dtype=torch.int32) + max_lengths = torch.tensor([[0]], dtype=torch.int32) + + if self.is_packing_enabled: + cu_lengths = torch.stack([s.cu_lengths for s in samples]) + max_lengths = torch.tensor([s.max_length for s in samples], dtype=torch.int32) + + return ImageTaskBatchPacked( + __key__=[s.__key__ for s in samples], + __restore_key__=[s.__restore_key__ for s in samples], + __subflavor__=None, + __subflavors__=samples[0].__subflavors__, + tokens=tokens, + labels=labels, + imgs=imgs, + num_tiles=num_tiles, + cu_lengths=cu_lengths, + max_lengths=max_lengths, + ) + + def encode_batch(self, batch: ImageTaskBatchPacked) -> dict: + raw = dataclasses.asdict(batch) + del raw["__subflavors__"] + return raw + + def select_samples_to_pack(self, samples: List[ImageTaskSample]) -> List[List[ImageTaskSample]]: + """Selects which samples will be packed together. + + NOTE: Energon dataloader calls this method internally if packing is used. + Please see https://nvidia.github.io/Megatron-Energon/packing.html + """ + lengths = [sample.total_len for sample in samples] + + packed_samples = greedy_knapsack(lengths, samples, self.packing_seq_length) + + return packed_samples + + @stateless + def pack_selected_samples(self, samples: List[ImageTaskSample]) -> List[ImageTaskSamplePacked]: + """ + Function to pack a list of ImageTaskSample into a single ImageTaskSamplePacked. + + NOTE: Energon dataloader calls this method internally if packing is used. + Please see https://nvidia.github.io/Megatron-Energon/packing.html + + Args: + samples: List of ImageTaskSample instances to pack into one sample. + + Returns: + ImageTaskSamplePacked instance. + """ + packing_seq_len = self.packing_seq_length + + packed_tokens = [] + packed_labels = [] + packed_imgs = [] + + current_length = 0 + max_length = 0 + cu_lengths = [0] + + # Process each sample and build lists that we will concatenate to create the packed sample. + for _, sample in enumerate(samples): + sample_len = sample.total_len + + if sample_len > max_length: + max_length = sample_len + + # If adding this sample exceeds the max length, stop. + # This should not happen. 
The select_samples_to_pack method should have already ensured that the samples fit. + if current_length + sample_len > packing_seq_len: + raise ValueError(f"Packed sample exceeds the maximum sequence length of {packing_seq_len}: {samples}") + + # Add the sample's tokens and labels + packed_tokens.append(sample.tokens) + packed_labels.append(sample.labels) + + # Add the images + packed_imgs += sample.imgs + + current_length += sample_len + cu_lengths.append(current_length) + + # Concatenate packed tokens and labels. + packed_tokens = torch.cat(packed_tokens, dim=0) + packed_labels = torch.cat(packed_labels, dim=0) + + return ImageTaskSamplePacked( + __key__=",".join([s.__key__ for s in samples]), + __restore_key__=(), # Will be set by energon based on `samples` + __subflavor__=None, + __subflavors__=samples[0].__subflavors__, + tokens=packed_tokens, + labels=packed_labels, + imgs=packed_imgs, + cu_lengths=torch.tensor(cu_lengths, dtype=torch.int32), + max_length=max_length, + num_tiles=[n for s in samples for n in s.num_tiles], + ) + + +def print_error_handler(exc: Exception, key: Optional[str]): + print( + f"The following exception occurred in the dataloader for sample {key} and is skipped", + file=sys.stderr, + ) + traceback.print_exc() + + +def format_multichoice_question(question, multichoice_options): + """Format multi-choice question.""" + options_text = ["{}. {}\n".format(chr(ord('A') + i), option) for i, option in + zip(range(len(multichoice_options)), multichoice_options)] + options_text = "".join(options_text) + + options_text = f"{options_text}Answer with the option's letter from the given choices directly." + + return "{}\n{}".format(question, options_text) + + +def format_multichoice_answer(idx): + """Format multi-choice answer.""" + return chr(ord('A') + idx) diff --git a/examples/multimodal/energon_util.py b/examples/multimodal/energon_util.py new file mode 100644 index 0000000000..661e691c53 --- /dev/null +++ b/examples/multimodal/energon_util.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import torch +import warnings +from dataclasses import dataclass +from typing import Any, List + +from megatron.energon import Sample +from megatron.energon.epathlib.epath import EPath +from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory + + +@dataclass +class SampleListSample(Sample): + """Sample type for a list of samples of any type which needs to be packed together. + + This is useful for datasets which are packed offline. 
+ """ + + #: The images of the sequence + samples: List[Any] + + +class SampleListWebdataset(DefaultDecoderWebdatasetFactory[SampleListSample]): + __sample_type__ = SampleListSample + + def __init__(self, path: EPath, **kwargs): + warnings.warn( + f"{type(self)} is deprecated, use the default instead and set the sample_type:\n" + f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n" + f"# remove top-level __module__ and __class__\n" + f"sample_type:\n" + f" __module__: megatron.energon\n" + f" __class__: {self.__sample_type__.__name__}\n" + f"# Keep the remaining content", + DeprecationWarning, + ) + super().__init__(path, **kwargs) + + +@dataclass +class OfflineTargetAspectRatioSample(Sample): + """Sample type for image + text samples with target aspect ratio computed offline.""" + + #: The images of the sequence + images: List[torch.Tensor] + #: The texts of the sequence + texts: List[str] + target_aspect_ratio: List[List] diff --git a/examples/multimodal/evaluation/evaluate_ai2d.py b/examples/multimodal/evaluation/evaluate_ai2d.py new file mode 100644 index 0000000000..39b866ae4a --- /dev/null +++ b/examples/multimodal/evaluation/evaluate_ai2d.py @@ -0,0 +1,52 @@ +import argparse +import json + +from evaluate_mmmu import get_input_output_paths +from evaluate_vqav2 import compute_vqa_accuracy + + +def merge_input_files(input_path): + """Merge input files to a format compatible with the evaluator.""" + input_file_paths, output_file_path = get_input_output_paths(input_path, task="AI2D") + + results = dict() + + for input_file_path in input_file_paths: + with open(input_file_path, "r") as input_file: + for line in input_file: + res = json.loads(line) + sample_id = res["sample_id"] + + # Ignore possible duplicates. + if sample_id in results: + continue + + results[sample_id] = { + "question_id": sample_id, + "answer": res["answer"], + "gt_answer": res["gt_answer"], + } + + results = list(results.values()) + + with open(output_file_path, "w") as output_file: + json.dump(results, output_file) + + return output_file_path + + +def ai2d_eval(input_path): + """Run AI2D evaluation.""" + result_file_path = merge_input_files(input_path) + avg_acc = compute_vqa_accuracy(result_file_path, task="AI2D") + return avg_acc + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--input-path', type=str, help="Path to input file(s)") + args = parser.parse_args() + + avg_acc = ai2d_eval(args.input_path) + + print(f"===== AI2D Accuracy {avg_acc:.2f}% =====") diff --git a/examples/multimodal/evaluation/evaluate_chartqa.py b/examples/multimodal/evaluation/evaluate_chartqa.py new file mode 100644 index 0000000000..53d4944f46 --- /dev/null +++ b/examples/multimodal/evaluation/evaluate_chartqa.py @@ -0,0 +1,48 @@ +import argparse +import json + +from evaluate_mmmu import get_input_output_paths +from evaluate_vqav2 import compute_vqa_accuracy + + +def merge_input_files(input_path): + """Merge input files to a format compatible with the evaluator.""" + input_file_paths, output_file_path = get_input_output_paths(input_path, task="ChartQA") + + results = dict() + + for input_file_path in input_file_paths: + with open(input_file_path, "r") as input_file: + for line in input_file: + res = json.loads(line) + sample_id = res["sample_id"] + + # Ignore possible duplicates. 
+ if sample_id in results: + continue + + res["question_id"] = sample_id + results[sample_id] = res + + results = list(results.values()) + + with open(output_file_path, "w") as output_file: + json.dump(results, output_file) + + return output_file_path + + +def chartqa_eval(input_path): + """Run ChartQA evaluation.""" + result_file_path = merge_input_files(input_path) + return compute_vqa_accuracy(result_file_path, task="ChartQA") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--input-path', type=str, help="Path to input file(s)") + args = parser.parse_args() + + avg_acc = chartqa_eval(args.input_path) + + print(f"ChartQA accuracy: {avg_acc:.2f}") diff --git a/examples/multimodal/evaluation/evaluate_coco.py b/examples/multimodal/evaluation/evaluate_coco.py new file mode 100644 index 0000000000..8eeb367e8f --- /dev/null +++ b/examples/multimodal/evaluation/evaluate_coco.py @@ -0,0 +1,66 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import argparse +import json + +from evaluate_mmmu import get_input_output_paths +from pycocoevalcap.eval import COCOEvalCap +from pycocotools.coco import COCO + + +def convert_to_coco_format(input_path): + """Convert input files to COCO compatible format.""" + input_file_paths, output_file_path = get_input_output_paths(input_path, task="captioning") + + results = dict() + + for input_file_path in input_file_paths: + with open(input_file_path, "r") as input_file: + for line in input_file: + res = json.loads(line) + sample_id = res["sample_id"] + + # Ignore possible duplicates. + if sample_id in results: + continue + + caption = res["caption"].rstrip(".").lower() + results[sample_id] = { + "image_id": sample_id, + "caption": caption, + } + + results = list(results.values()) + + with open(output_file_path, "w") as output_file: + json.dump(results, output_file, indent=4) + + return output_file_path + + +def coco_captioning_eval(input_path, groundtruth_file): + """Run COCO captioning evaluation.""" + coco = COCO(groundtruth_file) + input_file = convert_to_coco_format(input_path) + coco_result = coco.loadRes(input_file) + + coco_eval = COCOEvalCap(coco, coco_result) + + # Evaluate on the input subset of images. 
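
# For reference, convert_to_coco_format() above writes the flat list of
# {"image_id", "caption"} entries that pycocotools' COCO.loadRes() consumes in
# coco_captioning_eval(). The ids and captions below are made-up placeholders:
example_results = [
    {"image_id": 139, "caption": "a cat sitting on a wooden table"},
    {"image_id": 285, "caption": "a brown bear walking through a field"},
]
# Dumping such a list with json.dump() yields a valid argument for coco.loadRes() above.
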
+ coco_eval.params["image_id"] = coco_result.getImgIds() + + coco_eval.evaluate() + + print("========== COCO captioning scores ==========") + for metric, score in coco_eval.eval.items(): + print(f"{metric} {score * 100:.3f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input-path", type=str, required=True, help="Path to input file(s)") + parser.add_argument( + "--groundtruth-path", type=str, required=True, help="Path to groundtruth file" + ) + args = parser.parse_args() + + coco_captioning_eval(args.input_path, args.groundtruth_path) diff --git a/examples/multimodal/evaluation/evaluate_mathvista.py b/examples/multimodal/evaluation/evaluate_mathvista.py new file mode 100644 index 0000000000..a55f312f21 --- /dev/null +++ b/examples/multimodal/evaluation/evaluate_mathvista.py @@ -0,0 +1,122 @@ +import argparse +import json +import re + +from evaluate_mmmu import get_input_output_paths +from MMMU.mmmu.utils.eval_utils import parse_multi_choice_response +from open_flamingo.eval.vqa_metric import VQAEval + + +def merge_input_files(input_path): + """Merge input files to a format compatible with the evaluator.""" + input_file_paths, output_file_path = get_input_output_paths(input_path, task="MathVista") + + results = dict() + + for input_file_path in input_file_paths: + with open(input_file_path, "r") as input_file: + for line in input_file: + res = json.loads(line) + sample_id = res["sample_id"] + + # Remove possible duplicates. + if sample_id in results: + continue + + results[sample_id] = res + + results = list(results.values()) + + with open(output_file_path, "w") as output_file: + json.dump(results, output_file) + + return output_file_path + + +def extra_processing(text): + """Extra processing.""" + # Max decimal point capped to 2 decimal point + regex = re.compile(r'^\d+\.\d+$') + decimal = regex.findall(text) + + if len(decimal) > 0: + non_decimal = len(decimal[0].split(".")[0]) + + # if decimal values are all 0, trim them + decimal_digits = [int(d) for d in decimal[0].split(".")[1]] + if sum(decimal_digits) == 0: + text = decimal[0][:non_decimal] + else: + text = decimal[0][: non_decimal + 3] + + # remove % and trailing . + text = text.replace("%", "") + if text[-1] == ".": + text = text[:-1] + + return text + + +def extract_answer(text): + """Extract answer.""" + alphabet = re.findall(r'[a-zA-Z]+', text) + if len(alphabet) > 0 and "e+" not in text: + template = re.findall(r'answer is -*\d+\.*\d*', text) + if len(template) > 0: + text = template[0] + + numbers = re.findall(r'-*\d+\.*\d*', text) + text = numbers[0] if len(numbers) > 0 else text + + return text + + +def compute_mathvista_accuracy(result_file): + """Compute MathVista accuracy.""" + merged_results = json.load(open(result_file)) + + vqa = VQAEval(vqa=None, vqaRes=None) + acc = 0 + for res in merged_results: + pred_ans = res["answer"] + if res["question_type"] == "multi_choice": + pred_ans = parse_multi_choice_response(pred_ans, res["all_choices"], res["index2ans"]) + else: + pred_ans = vqa.processPunctuation(pred_ans) + pred_ans = vqa.processDigitArticle(pred_ans) + # Extra processing and extraction. 
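
# A couple of hand-picked examples of what the normalization helpers defined above do
# (values chosen for illustration, not taken from the dataset):
assert extra_processing("50.00") == "50"  # all-zero decimals are trimmed
assert extract_answer(extra_processing("The answer is 0.6")) == "0.6"
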
+ pred_ans = extra_processing(pred_ans) + pred_ans = extract_answer(pred_ans) + + gt_ans = res["gt_answer"] + if isinstance(gt_ans, list): + assert len(gt_ans) == 1, f"Expected 1 groundtruth, got {gt_ans}" + gt_ans = gt_ans[0] + + if res["question_type"] != "multi_choice": + gt_ans = vqa.processPunctuation(gt_ans) + gt_ans = vqa.processDigitArticle(gt_ans) + + gt_ans = extra_processing(gt_ans) + + if pred_ans == gt_ans: + acc += 1 + acc = acc / len(merged_results) * 100 + return acc + + +def mathvista_eval(input_path): + """Run MathVista evaluation.""" + result_file_path = merge_input_files(input_path) + acc = compute_mathvista_accuracy(result_file_path) + return acc + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--input-path', type=str, help="Path to input file(s)") + args = parser.parse_args() + + acc = mathvista_eval(args.input_path) + + print(f"===== MathVista accuracy: {acc} =====") diff --git a/examples/multimodal/evaluation/evaluate_mmmu.py b/examples/multimodal/evaluation/evaluate_mmmu.py new file mode 100644 index 0000000000..798c42bfa7 --- /dev/null +++ b/examples/multimodal/evaluation/evaluate_mmmu.py @@ -0,0 +1,116 @@ +import argparse +import glob +import json +import os +import sys +import re +import subprocess + +# Get the absolute path of the parent directory +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) +# Add the parent directory to sys.path +sys.path.insert(0, parent_dir) + +from run_text_generation import get_output_path +from config import EvaluationConfig + + +def get_input_output_paths(input_path, task): + """Get all input files and an output path for a merged file.""" + # Single input file. + if os.path.exists(input_path): + input_file_paths = [input_path] + output_file_path = input_path.replace(".jsonl", "-merged.json") + # Select multiple partitions and dp ranks. + else: + cfg = EvaluationConfig(task=task, output_path=input_path, partition_id="*") + pattern = get_output_path(cfg, dp_rank="*") + input_file_paths = glob.glob(pattern) + + output_file_path = input_path + f"-{task}-merged.json" + + return input_file_paths, output_file_path + + +def convert_to_mmmu_format(input_path): + """Convert input files to MMMU compatible format.""" + input_file_paths, output_file_path = get_input_output_paths(input_path, "MMMU") + + output = dict() + + for input_file_path in input_file_paths: + with open(input_file_path, "r") as input_file: + for line in input_file: + res = json.loads(line) + + sample_id = res["sample_id"] + prediction = res["prediction"] + + if res["question_type"] == "multiple-choice": + from MMMU.mmmu.utils.eval_utils import parse_multi_choice_response + + prediction = parse_multi_choice_response( + prediction, res["all_choices"], res["index2ans"] + ) + + # MMMU eval script expects just a sample_id to prediction mapping. + # Skip possible duplicates. + if sample_id in output: + continue + + output[sample_id] = prediction + + with open(output_file_path, "w") as output_file: + json.dump(output, output_file) + + return output_file_path + + +def mmmu_eval(input_path, groundtruth_path): + """Run MMMU evaluation.""" + result_file = convert_to_mmmu_format(input_path) + + # The MMMU repo has a script for running the actual evaluation but no API. So launching the script here. 
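
# The overall accuracy is then pulled out of the eval script's stdout with the re.search
# below. A standalone sketch against a made-up output line (the exact stdout format is an
# assumption here):
import re

fake_stdout = "{'Accounting': {'num': 30, 'acc': 0.4}, 'Overall': {'num': 900, 'acc': 0.512}}"
m = re.search(r"'Overall': {'num': \d+, 'acc': (\d.\d+)}", fake_stdout)
print(float(m.group(1)) * 100.0)  # ~51.2
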
+ output = subprocess.run( + [ + "python", + "examples/multimodal/MMMU/mmmu/main_eval_only.py", + "--output_path", + result_file, + "--answer_path", + groundtruth_path, + ], + capture_output=True, + text=True, + ) + + print(output.stderr) + print(output.stdout) + + m = re.search("'Overall': {'num': \d+, 'acc': (\d.\d+)}", output.stdout) + + return float(m.group(1)) * 100.0 + + +def main(): + """Run MMMU evaluation.""" + # Using the validation groundtruth file from the MMMU repo by default. This assumes you have cloned the MMMU github repo here. + default_groundtruth_path = "examples/multimodal/MMMU/mmmu/answer_dict_val.json" + + parser = argparse.ArgumentParser() + parser.add_argument("--input-path", type=str, required=True, help="Path to input file(s)") + parser.add_argument( + "--groundtruth-path", + type=str, + default=default_groundtruth_path, + help="Path to groundtruth file. Defaults to the validation file in the MMMU repo.", + ) + args = parser.parse_args() + + avg_acc = mmmu_eval(args.input_path, args.groundtruth_path) + + print(f"MMMU average accuracy: {avg_acc:.2f}") + + +if __name__ == "__main__": + main() diff --git a/examples/multimodal/evaluation/evaluate_ocrbench.py b/examples/multimodal/evaluation/evaluate_ocrbench.py new file mode 100644 index 0000000000..b37473a67d --- /dev/null +++ b/examples/multimodal/evaluation/evaluate_ocrbench.py @@ -0,0 +1,137 @@ +import argparse +import json + +from evaluate_mmmu import get_input_output_paths + + +def merge_input_files(input_path): + """Merge input files to a format compatible with the evaluator.""" + input_file_paths, output_file_path = get_input_output_paths(input_path, task="OCRBench") + + results = dict() + + for input_file_path in input_file_paths: + with open(input_file_path, "r") as input_file: + for line in input_file: + res = json.loads(line) + sample_id = res["sample_id"] + + # Remove possible duplicates. + if sample_id in results: + continue + + results[sample_id] = res + + results = list(results.values()) + + with open(output_file_path, "w") as output_file: + json.dump(results, output_file) + + return output_file_path + + +def compute_ocrbench_score(result_file): + """Compute OCRBench score.""" + merged_results = json.load(open(result_file)) + + # OCRBench score calculation is adopted from https://github.com/Yuliang-Liu/MultimodalOCR/blob/1b7713f44c91f30f64efb6d3e494c416861ef15f/example.py#L1 + # MIT License. 
Copyright (c) 2023 Yuliang Liu + score = { + "Regular Text Recognition": 0, + "Irregular Text Recognition": 0, + "Artistic Text Recognition": 0, + "Handwriting Recognition": 0, + "Digit String Recognition": 0, + "Non-Semantic Text Recognition": 0, + "Scene Text-centric VQA": 0, + "Doc-oriented VQA": 0, + "Doc-oriented VQA": 0, + "Key Information Extraction": 0, + "Handwritten Mathematical Expression Recognition": 0, + } + + for res in merged_results: + predict = res["answer"] + answers = res["gt_answer"] + + dataset_name = res["dataset_name"] + ocr_type = res["data_type"] + + if dataset_name == "HME100k": + if isinstance(answers, list): + for j in range(len(answers)): + answer = answers[j].strip().replace("\n", " ").replace(" ", "") + predict = predict.strip().replace("\n", " ").replace(" ", "") + if answer in predict: + score[ocr_type] += 1 + else: + answers = answers.strip().replace("\n", " ").replace(" ", "") + predict = predict.strip().replace("\n", " ").replace(" ", "") + if answers in predict: + score[ocr_type] += 1 + else: + if isinstance(answers, list): + for j in range(len(answers)): + answer = answers[j].lower().strip().replace("\n", " ") + predict = predict.lower().strip().replace("\n", " ") + if answer in predict: + score[ocr_type] += 1 + else: + answers = answers.lower().strip().replace("\n", " ") + predict = predict.lower().strip().replace("\n", " ") + if answers in predict: + score[ocr_type] += 1 + + recognition_score = ( + score['Regular Text Recognition'] + + score['Irregular Text Recognition'] + + score['Artistic Text Recognition'] + + score['Handwriting Recognition'] + + score['Digit String Recognition'] + + score['Non-Semantic Text Recognition'] + ) + final_score = ( + recognition_score + + score['Scene Text-centric VQA'] + + score['Doc-oriented VQA'] + + score['Key Information Extraction'] + + score['Handwritten Mathematical Expression Recognition'] + ) + result_log = f"""###########################OCRBench############################## +Text Recognition(Total 300): {recognition_score} +------------------Details of Recognition Score------------------- +Regular Text Recognition(Total 50): {score['Regular Text Recognition']} +Irregular Text Recognition(Total 50): {score['Irregular Text Recognition']} +Artistic Text Recognition(Total 50): {score['Artistic Text Recognition']} +Handwriting Recognition(Total 50): {score['Handwriting Recognition']} +Digit String Recognition(Total 50): {score['Digit String Recognition']} +Non-Semantic Text Recognition(Total 50): {score['Non-Semantic Text Recognition']} +---------------------------------------------------------------- +Scene Text-centric VQA(Total 200): {score['Scene Text-centric VQA']} +---------------------------------------------------------------- +Doc-oriented VQA(Total 200): {score['Doc-oriented VQA']} +---------------------------------------------------------------- +Key Information Extraction(Total 200): {score['Key Information Extraction']} +---------------------------------------------------------------- +Handwritten Mathematical Expression Recognition(Total 100): {score['Handwritten Mathematical Expression Recognition']} +----------------------Final Score------------------------------- +Final Score(Total 1000): {final_score}""" + + return result_log, final_score + + +def ocrbench_eval(input_path): + """Run OCRBench evaluation.""" + result_file_path = merge_input_files(input_path) + result_log, score = compute_ocrbench_score(result_file_path) + return result_log, score + + +if __name__ == "__main__": + parser = 
argparse.ArgumentParser() + parser.add_argument('--input-path', type=str, help="Path to input file(s)") + args = parser.parse_args() + + result_log, _ = ocrbench_eval(args.input_path) + + print(result_log) diff --git a/examples/multimodal/evaluation/evaluate_textvqa.py b/examples/multimodal/evaluation/evaluate_textvqa.py new file mode 100644 index 0000000000..af782bdf03 --- /dev/null +++ b/examples/multimodal/evaluation/evaluate_textvqa.py @@ -0,0 +1,52 @@ +import argparse +import json + +from evaluate_mmmu import get_input_output_paths +from evaluate_vqav2 import compute_vqa_accuracy + + +def merge_input_files(input_path): + """Merge input files to a format compatible with the evaluator.""" + input_file_paths, output_file_path = get_input_output_paths(input_path, task="TextVQA") + + results = dict() + + for input_file_path in input_file_paths: + with open(input_file_path, "r") as input_file: + for line in input_file: + res = json.loads(line) + sample_id = res["sample_id"] + + # Remove possible duplicates. + if sample_id in results: + continue + + results[sample_id] = { + "question_id": sample_id, + "answer": res["answer"], + "gt_answer": res["gt_answer"], + } + + results = list(results.values()) + + with open(output_file_path, "w") as output_file: + json.dump(results, output_file) + + return output_file_path + + +def textvqa_eval(input_path): + """Run TextVQA evaluation.""" + result_file_path = merge_input_files(input_path) + avg_acc = compute_vqa_accuracy(result_file_path, task="TextVQA") + return avg_acc + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--input-path', type=str, help="Path to input file(s)") + args = parser.parse_args() + + avg_acc = textvqa_eval(args.input_path) + + print(f"===== TextVQA Accuracy {avg_acc:.2f}% =====") diff --git a/examples/multimodal/evaluation/evaluate_vqav2.py b/examples/multimodal/evaluation/evaluate_vqav2.py new file mode 100644 index 0000000000..7807d80723 --- /dev/null +++ b/examples/multimodal/evaluation/evaluate_vqav2.py @@ -0,0 +1,109 @@ +import argparse +import json + +from evaluate_mmmu import get_input_output_paths +from open_flamingo.eval.vqa_metric import VQAEval + + +def merge_input_files(input_path): + """Merge input files to a format compatible with the evaluator.""" + input_file_paths, output_file_path = get_input_output_paths(input_path, task="VQAv2") + + results = dict() + + for input_file_path in input_file_paths: + with open(input_file_path, "r") as input_file: + for line in input_file: + res = json.loads(line) + sample_id = res["sample_id"] + + # Skip possible duplicates. 
+ if sample_id in results: + continue + + res["question_id"] = sample_id + results[sample_id] = res + + results = list(results.values()) + + with open(output_file_path, "w") as output_file: + json.dump(results, output_file) + + return output_file_path + + +def is_number(n: str): + """Check if input is a number.""" + try: + float(n) + return True + except ValueError: + return False + + +def compute_vqa_accuracy(result_file, task): + """Compute VQA accuracy.""" + merged_results = json.load(open(result_file)) + + vqa = VQAEval(vqa=None, vqaRes=None) + all_acc = [] + for res in merged_results: + pred = res["answer"] + pred = vqa.processPunctuation(pred) + pred = vqa.processDigitArticle(pred) + + gt = res["gt_answer"] + gt = [vqa.processPunctuation(ans) for ans in gt] + gt = [vqa.processDigitArticle(ans) for ans in gt] + + # ChartQA uses relaxed accuracy: + # "We consider an answer to be correct if it is within 5% of the gold answer. + # For non-numeric answers, we still need an exact match to consider an answer to be correct." + if task == "ChartQA": + acc = 0.0 + assert len(gt) == 1, "expected exactly one groundtruth answer." + gt = gt[0] + + pred = pred.rstrip("%") + gt = gt.rstrip("%") + + if is_number(pred) and is_number(gt): + pred = float(pred) + gt = float(gt) + if pred >= (gt * 0.95) and pred <= (gt * 1.05): + acc = 1.0 + elif pred == gt: + acc = 1.0 + + all_acc.append(acc) + elif task in ("VQAv2", "TextVQA"): + num_match = sum([pred == ans for ans in gt]) + acc = min(1.0, num_match / 3.0) + all_acc.append(acc) + elif task == "AI2D": + assert len(gt) == 1, f"Expected exactly 1 GT, got {gt}" + acc = pred == gt[0] + all_acc.append(acc) + else: + raise NotImplementedError(f"unknown task {task}") + + acc_avg = sum(all_acc) / len(all_acc) * 100 + + return acc_avg + + +def vqav2_eval(input_path): + """Run VQAv2 evaluation.""" + result_file = merge_input_files(input_path) + avg_acc = compute_vqa_accuracy(result_file, task="VQAv2") + return avg_acc + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--input-path', type=str, help="Path to input file(s)") + args = parser.parse_args() + + avg_acc = vqav2_eval(args.input_path) + + print(f"===== VQAv2 Accuracy {avg_acc:.2f}% =====") diff --git a/examples/multimodal/evaluation/evaluation_datasets.py b/examples/multimodal/evaluation/evaluation_datasets.py new file mode 100644 index 0000000000..dc7ebf3c30 --- /dev/null +++ b/examples/multimodal/evaluation/evaluation_datasets.py @@ -0,0 +1,914 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
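
# A distilled, standalone version of the ChartQA "relaxed accuracy" rule applied in
# compute_vqa_accuracy() in evaluate_vqav2.py above: numeric answers may deviate by up to
# 5% from the gold value, everything else requires an exact match. The helper name and the
# example values are illustrative only.
def relaxed_match(pred: str, gt: str) -> bool:
    pred, gt = pred.rstrip("%"), gt.rstrip("%")
    try:
        p, g = float(pred), float(gt)
        return g * 0.95 <= p <= g * 1.05
    except ValueError:
        return pred == gt

# relaxed_match("41.9", "42") -> True ;  relaxed_match("39", "42") -> False
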
+"""Evaluation datasets.""" +import glob +import itertools +import json +import os +import re +from collections import defaultdict + +import numpy as np +import torch +from image_processing import get_visual_transform +from PIL import Image + +from megatron.training import print_rank_0 + + +def _get_partition_bounds( + total_num_samples, num_samples_per_partition, num_partitions, partition_id +): + if num_samples_per_partition == 0: + samples_per_partition = [ + int(x) for x in np.linspace(0, total_num_samples, num_partitions + 1) + ] + return samples_per_partition[partition_id], samples_per_partition[partition_id + 1] + return num_samples_per_partition * partition_id, num_samples_per_partition * (partition_id + 1) + + +class VQADataset(torch.utils.data.Dataset): + """VQA evaluation dataset.""" + + def __init__( + self, + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + keys, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ): + samples = json.load(open(gt_path, encoding='utf-8')) + if "data" in samples: + samples = samples["data"] + + # Optionally, process only a subset of the input files. + if num_partitions > 0: + lb, ub = _get_partition_bounds( + len(samples), num_samples_per_partition, num_partitions, partition_id + ) + samples = samples[lb:ub] + + self._keys = keys + self._samples = samples + self._input_image_path = input_image_path + self._img_h = img_h + self._img_w = img_w + self._use_tiling = use_tiling + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + self._vision_model_type = vision_model_type + + def __len__(self): + return len(self._samples) + + def __getitem__(self, idx): + sample = self._samples[idx] + + img_file = "{}/{}".format(self._input_image_path, sample[self._keys["image_id"]]) + if not os.path.exists(img_file): + img_file += ".jpg" + + if not os.path.exists(img_file): + img_file = img_file.replace('.jpg', '.png') + + img = Image.open(img_file) + imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + self._max_num_tiles, + self._use_thumbnail, + augment=False, + vision_model_type=self._vision_model_type, + ) + tile_count = torch.tensor([len(imgs)], dtype=torch.int) + + sample_id = idx + if "sample_id" in self._keys: + sample_id = sample[self._keys["sample_id"]] + + metadata = "" # Not used. + + return ( + torch.stack(imgs), + tile_count, + sample_id, + sample[self._keys["question"]], + sample[self._keys["answer"]], + metadata, + ) + + +class CaptioningDataset(torch.utils.data.Dataset): + """Captioning evaluation dataset.""" + + def __init__( + self, + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ): + image_files = sorted(glob.glob(input_image_path + "/*")) + + # Optionally, process only a subset of the input files. 
+ if num_partitions > 0: + lb, ub = _get_partition_bounds( + len(image_files), num_samples_per_partition, num_partitions, partition_id + ) + image_files = image_files[lb:ub] + + gts = json.load(open(gt_path)) + answers = defaultdict(list) + for gt in gts["annotations"]: + answers[gt["image_id"]].append(gt['caption']) + + self._image_files = image_files + self._answers = answers + self._img_h = img_h + self._img_w = img_w + self._use_tiling = use_tiling + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + self._vision_model_type = vision_model_type + + def __len__(self): + return len(self._image_files) + + def __getitem__(self, idx): + img_file = self._image_files[idx] + image_id = int(img_file.split("_")[-1].split(".")[0]) + + img = Image.open(img_file) + imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + self._max_num_tiles, + self._use_thumbnail, + augment=False, + vision_model_type=self._vision_model_type, + ) + + tile_count = torch.tensor([len(imgs)], dtype=torch.int) + + question = "" # Fixed for all samples. + metadata = "" # Not used. + + return torch.stack(imgs), tile_count, image_id, question, self._answers[image_id], metadata + + +class MMMUDataset(torch.utils.data.Dataset): + """MMMU evaluation dataset.""" + + def __init__( + self, + input_image_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + prompt_style, + vision_model_type, + ): + import datasets + from MMMU.mmmu.utils.data_utils import CAT_SHORT2LONG, load_yaml + + # The following downloads the MMMU dataset from HuggingFace and uses the API from the MMMU github repo to run MMMU evaluation. + all_mmmu_datasets = [] + + hf_datasets_cache = os.environ["HF_DATASETS_CACHE"] + assert hf_datasets_cache != "", "Please set the environment variable HF_DATASETS_CACHE." + + for subject in CAT_SHORT2LONG.values(): + # Use a local copy of the dataset if exists (can be faster) or the HF one. + if os.path.exists(input_image_path): + subject_dataset = datasets.load_dataset( + os.path.join(input_image_path, subject), + split=datasets.Split.VALIDATION, + cache_dir=hf_datasets_cache, + verification_mode="no_checks", + ) + else: + subject_dataset = datasets.load_dataset( + "MMMU/MMMU", + subject, + split=datasets.Split.VALIDATION, + cache_dir=hf_datasets_cache, + ) + + all_mmmu_datasets.append(subject_dataset) + + dataset = datasets.concatenate_datasets(all_mmmu_datasets) + + dataset = [s for s in dataset if s['id'].startswith("val")] + + # Optionally, process only a subset of the input files. + if num_partitions > 0: + lb, ub = _get_partition_bounds( + len(dataset), num_samples_per_partition, num_partitions, partition_id + ) + dataset = dataset[lb:ub] + + # Using the LLaVA config from the MMMU repo. + config = load_yaml("examples/multimodal/MMMU/mmmu/configs/llava1.5.yaml") + for k, v in config.items(): + if isinstance(v, list): + assert len(v) == 1, "only one value supported." 
+ config[k] = v[0] + + self._config = config + + self._dataset = dataset + + self._img_h = img_h + self._img_w = img_w + self._use_tiling = use_tiling + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + self._prompt_style = prompt_style + self._vision_model_type = vision_model_type + + def __len__(self): + return len(self._dataset) + + def __getitem__(self, idx): + from MMMU.mmmu.utils.data_utils import construct_prompt, process_single_sample + + sample = self._dataset[idx] + + # Use the single image approach from the MMMU repo. + if self._prompt_style == "single_image": + sample = process_single_sample(sample) + sample = construct_prompt(sample, self._config) + + img = sample["image"] + sample_imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + self._max_num_tiles, + self._use_thumbnail, + augment=False, + vision_model_type=self._vision_model_type, + ) + sample_num_tiles = [len(sample_imgs)] + + prompt = sample["final_input_prompt"] + for i in range(8): + prompt = prompt.replace(f"", "") + sample["final_input_prompt"] = f"\n{prompt}" + elif self._prompt_style == "vlmevalkit": + sample = construct_prompt(sample, self._config) + + if sample["question_type"] == "multiple-choice": + question = sample["question"] + + options = "" + for k, v in sample["index2ans"].items(): + options += f"{k}. {v}\n" + + final_prompt = f"{question}\n" + if "hint" in sample: + final_prompt += f"Hint: {sample['hint']}\n" + + if "task_instructions" in sample: + final_prompt += f"Task instructions: {sample['task_instructions']}\n" + + final_prompt += options + final_prompt += "Answer with the option's letter from the given choices directly." + + sample["final_input_prompt"] = final_prompt.rstrip() + else: + question = sample["question"] + final_prompt = f"{question}\n" + final_prompt += "Answer the question directly." + sample["final_input_prompt"] = final_prompt.rstrip() + + sample_imgs = [] + sample_num_tiles = [] + + img_indices = sorted(list(set(re.findall(r"" + + img = sample[img_key] + assert img is not None, f"{img_str} is in prompt but not in sample images" + + imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + adjusted_max_num_tiles, + self._use_thumbnail, + augment=False, + vision_model_type=self._vision_model_type, + ) # List of tiles. + + sample_imgs.extend(imgs) + sample_num_tiles.append(len(imgs)) + + sample["final_input_prompt"] = " ".join([f'' for i in range(len(img_indices))]) + "\n" + sample["final_input_prompt"] + elif self._prompt_style == "multi_image": + sample = construct_prompt(sample, self._config) + + sample_imgs = [] + sample_num_tiles = [] + + img_indices = re.findall(r"" + + img = sample[img_key] + assert img is not None, f"{img_str} is in prompt but not in sample images" + + # Note: Only replace the current image tag. + sample["final_input_prompt"] = sample["final_input_prompt"].replace( + img_str, "", 1 + ) + + imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + adjusted_max_num_tiles, + self._use_thumbnail, + augment=False, + vision_model_type=self._vision_model_type, + ) # List of tiles. + + sample_imgs.extend(imgs) + sample_num_tiles.append(len(imgs)) + + # Sanity check. + for i in range(1, 8): + assert ( + f"" not in sample["final_input_prompt"] + ), "prompt contains unhandled image tags" + else: + raise ValueError(f"unknown prompt style {self._prompt_style}") + + # MMMU specific metadata. 
+ metadata = {"question_type": sample["question_type"]} + if sample["question_type"] == "multiple-choice": + metadata["index2ans"] = sample["index2ans"] + metadata["all_choices"] = sample["all_choices"] + + prompt = sample['final_input_prompt'] + + tile_count = torch.tensor(sample_num_tiles, dtype=torch.int) + + return ( + torch.stack(sample_imgs), + tile_count, + sample["id"], + prompt, + sample["answer"], + metadata, + ) + + +class VideoMMEDataset(torch.utils.data.Dataset): + "Video MME evaluation dataset." + + def __init__( + self, + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + num_frames, + vision_model_type, + ): + ground_truth_original = json.load(open(gt_path)) + ground_truth = [] + for gt in ground_truth_original: + video_path = gt["url"] + video_path = video_path.replace("https://www.youtube.com/watch?v=", "") + video_path = video_path.replace("https://m.youtube.com/watch?v=", "") + video_path = os.path.join(input_image_path, video_path + ".mp4") + if not os.path.exists(video_path): + continue + gt["video_path"] = video_path + ground_truth.append(gt) + + ground_truth = sorted(ground_truth, key=lambda gt: gt["video_path"]) + print_rank_0(f"Found {len(ground_truth)} videos to process.") + + if num_partitions > 0: + start_idx, end_idx = _get_partition_bounds( + len(ground_truth), num_samples_per_partition, num_partitions, partition_id + ) + ground_truth = ground_truth[start_idx:end_idx] + + self._ground_truth = ground_truth + self._img_h = img_h + self._img_w = img_w + self._use_tiling = False + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + self._num_frames = num_frames + self._vision_model_type = vision_model_type + + def __len__(self): + return len(self._ground_truth) + + def __getitem__(self, idx): + from torchvision.io import read_video + + gt = self._ground_truth[idx] + + video, _, _ = read_video(gt["video_path"], start_pts=0, end_pts=None, pts_unit='sec') + video = video.numpy() + selected_frames = torch.linspace(0, video.shape[0] - 1, self._num_frames).long() + video_frames = video[selected_frames] + if self._num_frames == 1: + video_frames = video_frames[None] + + imgs = [] + for img in video_frames: + from torchvision.transforms import ToPILImage + to_pil = ToPILImage() + img = to_pil(img) + imgs += get_visual_transform( + img, self._img_h, self._img_w, self._use_tiling, self._max_num_tiles, + self._use_thumbnail, augment=False, vision_model_type=self._vision_model_type + ) + + for question in gt["questions"]: + # Very hacky, but we essentially re-create gt holding only the + # question of interest. This is the make this generation script + # compatible with the Video MME evaluation script. 
+ question_dict = { + "video_id": gt["video_id"], + "duration_category": gt["duration_category"], + "video_category": gt["video_category"], + "video_subcategory": gt["video_subcategory"], + "url": gt["url"], + "questions": [question], + } + + num_tiles = torch.tensor([len(imgs)], dtype=torch.int) + + answer = "" + metadata = "" + + return ( + torch.stack(imgs), + num_tiles, + question["question_id"], + question_dict, + answer, + metadata, + ) + + +class OCRBenchDataset(torch.utils.data.Dataset): + """OCRBench evaluation dataset.""" + + def __init__( + self, + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ): + gt = json.load(open(gt_path, encoding='utf-8')) + + if num_partitions > 0: + start_idx, end_idx = _get_partition_bounds( + len(gt), num_samples_per_partition, num_partitions, partition_id + ) + gt = gt[start_idx:end_idx] + + self._input_image_path = input_image_path + self._gt = gt + self._img_h = img_h + self._img_w = img_w + self._use_tiling = use_tiling + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + self._vision_model_type = vision_model_type + + def __len__(self): + return len(self._gt) + + def __getitem__(self, idx): + img_path = os.path.join(self._input_image_path, self._gt[idx]['image_path']) + + img = Image.open(img_path) + imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + self._max_num_tiles, + self._use_thumbnail, + augment=False, + vision_model_type=self._vision_model_type, + ) + + tile_count = torch.tensor([len(imgs)], dtype=torch.int) + + metadata = { + "dataset_name": self._gt[idx]["dataset_name"], + "data_type": self._gt[idx]["type"], + } + + return ( + torch.stack(imgs), + tile_count, + idx, + self._gt[idx]["question"], + self._gt[idx]["answers"], + metadata, + ) + + +class MathVistaDataset(torch.utils.data.Dataset): + """MathVista evaluation dataset.""" + + def __init__( + self, + input_image_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ): + import datasets + + hf_datasets_cache = os.environ["HF_DATASETS_CACHE"] + assert hf_datasets_cache != "", "Please set the environment variable HF_DATASETS_CACHE." + + if os.path.exists(input_image_path): + dataset = datasets.load_dataset( + input_image_path, cache_dir=hf_datasets_cache, verification_mode="no_checks" + ) + else: + dataset = datasets.load_dataset( + "AI4Math/MathVista", split="testmini", cache_dir=hf_datasets_cache + ) + + if num_partitions > 0: + start_idx, end_idx = _get_partition_bounds( + len(dataset), num_samples_per_partition, num_partitions, partition_id + ) + dataset = dataset[start_idx:end_idx] + + self._dataset = dataset + self._img_h = img_h + self._img_w = img_w + self._use_tiling = use_tiling + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + self._vision_model_type = vision_model_type + + def __len__(self): + return len(self._dataset["pid"]) + + def __getitem__(self, idx): + # Already a PIL object. 
+ img = self._dataset['decoded_image'][idx] + + imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + self._max_num_tiles, + self._use_thumbnail, + augment=False, + vision_model_type=self._vision_model_type, + ) + + tile_count = torch.tensor([len(imgs)], dtype=torch.int) + + question_id = self._dataset["pid"][idx] + question = self._dataset["question"][idx] + question_type = self._dataset["question_type"][idx] # free_form or multi_choice + query = self._dataset["query"][idx] + choices = self._dataset["choices"][idx] + answer = self._dataset["answer"][idx] + + if question_type == 'multi_choice': + start_chr = 'A' + choices_str = '' + index2ans = {} + all_choices = [] + for choice in choices: + all_choices.append(start_chr) + index2ans[start_chr] = choice + choices_str += f"{start_chr}. {choice}\n" + start_chr = chr(ord(start_chr) + 1) + + question = question + '\n' + choices_str + question = question + "Answer with the option's letter from the given choices directly." + answer = chr(ord('A') + choices.index(answer)) + else: + question = query.replace("Hint: ", "") + index2ans = {} + all_choices = [] + + metadata = { + "question_type": question_type, + "index2ans": index2ans, + "all_choices": all_choices, + } + + return torch.stack(imgs), tile_count, question_id, question, answer, metadata + + +class AI2DDataset(torch.utils.data.Dataset): + """AI2D evaluation dataset.""" + + def __init__( + self, + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + no_mask, + vision_model_type, + ): + with open(gt_path, 'r') as f: + jsonl = list(f) + + gt = [json.loads(json_str) for json_str in jsonl] + + if num_partitions > 0: + start_idx, end_idx = _get_partition_bounds( + len(gt), num_samples_per_partition, num_partitions, partition_id + ) + gt = gt[start_idx:end_idx] + + self._gt = gt + self._input_image_path = input_image_path + self._img_h = img_h + self._img_w = img_w + self._use_tiling = use_tiling + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + self._no_mask = no_mask + self._vision_model_type = vision_model_type + + def __len__(self): + return len(self._gt) + + def __getitem__(self, idx): + img_path = os.path.join(self._input_image_path, self._gt[idx]['image']) + if self._no_mask: + img_path.replace("AI2D_TEST", "AI2D_TEST_NO_MASK_IMAGES") + + img = Image.open(img_path) + imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + self._max_num_tiles, + self._use_thumbnail, + augment=False, + vision_model_type=self._vision_model_type, + ) + + tile_count = torch.tensor([len(imgs)], dtype=torch.int) + + metadata = "" # Not used. 
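
# For reference, the MathVista multi_choice branch above is equivalent to the following
# compact form; the choices and answer are made-up values:
choices, answer = ["2", "4", "8"], "4"
options_str = "".join(f"{chr(ord('A') + i)}. {c}\n" for i, c in enumerate(choices))
# options_str == "A. 2\nB. 4\nC. 8\n", appended to the question together with
# "Answer with the option's letter from the given choices directly."
letter_answer = chr(ord('A') + choices.index(answer))  # -> "B"
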
+ + return ( + torch.stack(imgs), + tile_count, + self._gt[idx]["question_id"], + self._gt[idx]["question"], + self._gt[idx]["answer"], + metadata, + ) + + +def get_evaluation_dataset( + task, + input_image_path, + gt_path, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + num_samples_per_partition, + num_partitions, + partition_id, + num_frames, + vision_model_type, +): + """Get an evaluation dataset.""" + if task == "TextVQA": + keys = { + "image_id": "image_id", + "sample_id": "question_id", + "question": "question", + "answer": "answers", + } + + dataset = VQADataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + keys, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ) + elif task == "VQAv2": + keys = { + "image_id": "image", + "sample_id": "question_id", + "question": "question", + "answer": "answer", + } + + dataset = VQADataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + keys, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ) + elif task == "ChartQA": + keys = {"image_id": "imgname", "question": "query", "answer": "label"} + + dataset = VQADataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + keys, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ) + elif task == "captioning": + dataset = CaptioningDataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ) + elif task == 'MMMU': + # Note: + # - prompt_style="single_image" uses only one image like in the MMMU repo example. + # - prompt_style="multi_image" uses multiple input images. 
+ # - prompt_style="vlmevalkit" is similar to https://github.com/open-compass/VLMEvalKit/blob/5d3cebcf18ef4bfbadc3bd3ef80bdc7aad2c6557/vlmeval/vlm/internvl_chat.py#L499 + dataset = MMMUDataset( + input_image_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + prompt_style="single_image", + vision_model_type=vision_model_type, + ) + elif task == "VideoMME": + dataset = VideoMMEDataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + num_frames, + vision_model_type, + ) + elif task == "OCRBench": + dataset = OCRBenchDataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ) + elif task == "MathVista": + dataset = MathVistaDataset( + input_image_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + vision_model_type, + ) + elif task == "AI2D": + dataset = AI2DDataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + no_mask=False, + vision_model_type=vision_model_type, + ) + else: + raise NotImplementedError(f"unsupported task {task}") + + return dataset diff --git a/examples/multimodal/image_processing.py b/examples/multimodal/image_processing.py new file mode 100644 index 0000000000..15175a634e --- /dev/null +++ b/examples/multimodal/image_processing.py @@ -0,0 +1,143 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE. +from torchvision import transforms as T +from torchvision.transforms import Compose +from torchvision.transforms.functional import InterpolationMode + + +IMAGENET_PIXEL_MEAN = [0.485, 0.456, 0.406] +IMAGENET_PIXEL_STD = [0.229, 0.224, 0.225] +SIGLIP_PIXEL_MEAN = [0.5, 0.5, 0.5] +SIGLIP_PIXEL_STD = [0.5, 0.5, 0.5] +CLIP_PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073] +CLIP_PIXEL_STD = [0.26862954, 0.26130258, 0.27577711] + + +pixel_statistics = { + "clip": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), + "siglip": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), + "internvit": (IMAGENET_PIXEL_MEAN, IMAGENET_PIXEL_STD), + "radio": (CLIP_PIXEL_MEAN, CLIP_PIXEL_STD), + "huggingface": (SIGLIP_PIXEL_MEAN, SIGLIP_PIXEL_STD), +} + + +# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L685 +# Copyright (c) 2023 OpenGVLab. +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def find_closest_area_weighted_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + """ + Find the best number of tiles based on the aspect ratio and the area covered by the tiles. 
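+
+    Each candidate tile grid is scored by the ratio of total tile area to the original image
+    area (capped at 0.6), multiplied by how closely the grid's aspect ratio matches the
+    original image's aspect ratio; the grid with the highest score is returned.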
+ """ + best_factor = float('-inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + factor_based_on_area_n_ratio = ( + min((ratio[0]*ratio[1]*image_size*image_size)/ area, 0.6) * + min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio)) + if factor_based_on_area_n_ratio > best_factor: + best_factor = factor_based_on_area_n_ratio + best_ratio = ratio + return best_ratio + + +def get_visual_transform( + img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False, + vision_model_type="clip", find_closest_aspect_ratio_fn=find_closest_aspect_ratio): + pixel_mean, pixel_std = pixel_statistics[vision_model_type] + + assert not augment, "Image augmentation not implemented." + transform = build_transform(img_h, pixel_mean, pixel_std, vision_model_type) + + if use_tiling: + assert img_h == img_w, "dynamic tiling expects equal tile height and width" + imgs = dynamic_preprocess( + img, min_num=1, max_num=max_num_tiles, image_size=img_h, use_thumbnail=use_thumbnail, + find_closest_aspect_ratio_fn=find_closest_aspect_ratio_fn) + imgs = [transform(img) for img in imgs] + else: + imgs = [transform(img)] + + return imgs + + +# From https://github.com/OpenGVLab/InternVL/blob/c62fa4f7c850165d7386bdc48ac6bc5a6fab0864/internvl_chat/internvl/train/dataset.py#L702 +# Copyright (c) 2023 OpenGVLab. +def dynamic_preprocess( + image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, + find_closest_aspect_ratio_fn=find_closest_aspect_ratio): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if + i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio_fn( + aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + +# Based on https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/clip/clip.py#L79 +# and https://github.com/OpenGVLab/InternVL/blob/aa521e6eb1df4cf153aa4118fcf13e673c055d46/internvl_chat/internvl/train/dataset.py#L276 +def build_transform(input_size, pixel_mean, pixel_std, vision_model_type): + if vision_model_type in ("siglip", "internvit", "radio", "huggingface"): + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.Resize((input_size, input_size), 
interpolation=InterpolationMode.BICUBIC),
+            T.ToTensor(),
+            T.Normalize(mean=pixel_mean, std=pixel_std)
+        ])
+    elif vision_model_type == "clip":
+        transform = Compose([
+            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+            T.ToTensor(),
+            T.Normalize(mean=pixel_mean, std=pixel_std),
+        ])
+    else:
+        raise NotImplementedError(f"image processing not defined for vision model {vision_model_type}")
+
+    return transform
diff --git a/examples/multimodal/layer_specs.py b/examples/multimodal/layer_specs.py
new file mode 100644
index 0000000000..dff1181d94
--- /dev/null
+++ b/examples/multimodal/layer_specs.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+import torch
+
+from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
+from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
+from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
+from megatron.core.transformer.dot_product_attention import DotProductAttention
+from megatron.core.transformer.enums import AttnMaskType
+from megatron.core.transformer.identity_op import IdentityOp
+from megatron.core.transformer.mlp import MLP, MLPSubmodules
+from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
+
+try:
+    from megatron.core.extensions.transformer_engine import (
+        TEColumnParallelLinear,
+        TEDotProductAttention,
+        TELayerNormColumnParallelLinear,
+        TENorm,
+        TERowParallelLinear,
+    )
+
+    HAVE_TE = True
+except ImportError:
+    HAVE_TE = False
+
+try:
+    import apex
+
+    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
+    from megatron.core.transformer.torch_norm import WrappedTorchNorm
+
+    HAVE_APEX = True
+    LNImpl = FusedLayerNorm
+except ImportError:
+    import warnings
+
+    from megatron.core.transformer.torch_norm import WrappedTorchNorm
+
+    warnings.warn(f'Apex is not installed. Falling back to Torch Norm')
+    LNImpl = WrappedTorchNorm
+
+
+def get_layer_spec(is_vit, normalization) -> ModuleSpec:
+    attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal
+    if normalization == "LayerNorm":
+        norm = LNImpl
+    elif normalization == "RMSNorm":
+        if HAVE_TE:
+            norm = TENorm
+        else:
+            version = torch.__version__.split('.')
+            version_geq_2_4 = (
+                int(version[0]) > 2
+                or (
+                    int(version[0]) == 2
+                    and int(version[1]) >= 4
+                )
+            )
+            assert version_geq_2_4, "Torch version >= 2.4.0 is required for RMSNorm"
+            if HAVE_APEX:
+                warnings.warn(f'Apex does not support RMSNorm. Falling back to Torch Norm')
+            norm = WrappedTorchNorm
+    else:
+        raise RuntimeError("unknown normalization", normalization)
+
+    mlp = get_mlp_module_spec(use_te=False)  # doesn't include norm.
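+    # The norms are supplied explicitly below via input_layernorm and pre_mlp_layernorm,
+    # which is why this local spec requests a norm-free MLP.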
+ + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=norm, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": attn_mask_type}, + submodules=SelfAttentionSubmodules( + linear_qkv=ColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=RowParallelLinear, + q_layernorm=IdentityOp, + k_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=norm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ), + ) + + +def get_layer_spec_te(is_vit=False, padding=False) -> ModuleSpec: + attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal + # Padding mask is needed for e.g. Context Parallel. + if padding: + assert not is_vit, "padding_causal mask not used with ViT" + attn_mask_type = AttnMaskType.padding_causal + + mlp = get_norm_mlp_module_spec_te() + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": attn_mask_type}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + k_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ), + ) + + +def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec: + # Dense MLP w/ or w/o TE modules. + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear, + linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, + ), + ) + + +def get_norm_mlp_module_spec_te() -> ModuleSpec: + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + ) diff --git a/examples/multimodal/manual_prompts.json b/examples/multimodal/manual_prompts.json new file mode 100644 index 0000000000..b0dfd84801 --- /dev/null +++ b/examples/multimodal/manual_prompts.json @@ -0,0 +1,48 @@ +{ + "COMMENT": "Sources for these prompts include https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/viewer and https://huggingface.co/datasets/HuggingFaceM4/M3IT", + "Captioning": { + "raw": [ + "Can you briefly explain what you see in the image?", + "Describe what's happening in this image in one short sentence.", + "Write a short caption that accurately represents the content of this image.", + "Please generate a descriptive caption for the image provided.", + "How would you summarize the scene depicted in the picture in short?", + "Describe the image briefly.", + "Write a succinct description of the image, capturing its main components, the relationships between them, and any notable details.", + "Create a concise caption that accurately describes the main elements in the image provided.", + "Write a brief, yet comprehensive, description of the image.", + "Describe the image in a clear and concise manner.", + "For the given image, provide a one-sentence summary that captures the most important details.", + "Generate a short caption for the picture.", + "Write a short and informative description that highlights the primary subjects and actions occurring in the given image.", + "Provide a concise and informative caption for the image, focusing on the primary subjects.", + "Write a clear description of the image, make sure the key features are well 
covered.", + "Offer a succinct explanation of the picture presented." + ] + }, + "CaptioningPretraining": { + "raw": [ + "Generate a short caption of the image.", + "Describe the image concisely.", + "Provide a brief description of the given image." + ], + "llava": [ + "Give a brief description of image.", + "Give a brief description of the image.", + "Provide a brief description of the given image.", + "Provide a one-sentence caption for the provided image.", + "Write a terse but informative summary of the picture.", + "Describe the image concisely.", + "Generate a clear and concise summary of the photo." + ] + }, + "OCR": { + "raw": [ + "Can you read the text from image and output here?", + "Extract and document the text from the provided image.", + "Converting the text embedded in this image into a readable document.", + "Transcribe all the text you find.", + "Can you extract all visible text from the image here?" + ] + } +} diff --git a/examples/multimodal/model.py b/examples/multimodal/model.py new file mode 100644 index 0000000000..1a82b12037 --- /dev/null +++ b/examples/multimodal/model.py @@ -0,0 +1,254 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import warnings +from copy import deepcopy + +import torch +from config import get_language_model_config, get_vision_model_config, get_vision_projection_config +from layer_specs import get_layer_spec, get_layer_spec_te, get_mlp_module_spec, get_norm_mlp_module_spec_te + +from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN, LLaVAModel +from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings +from megatron.training import get_args, get_tokenizer, print_rank_0 +from megatron.training.arguments import core_transformer_config_from_args + + +def model_provider( + pre_process=True, post_process=True, add_encoder=True, add_decoder=True, parallel_output=True +) -> LLaVAModel: + """Builds the model. + + Args: + pre_process (bool): Include the embedding layer in the gpt decoder (used with pipeline parallelism). Defaults to True. + post_process (bool): Include an output layer and a layernorm in the gpt decoder (used with pipeline parallelism). Defaults to True. + add_encoder (bool): Construct the encoder module (used with pipeline parallelism). Defaults to True. When we use pipelining, the encoder + will live on only a subset of the pipeline stages (specifically, only the first stage). + add_decoder (bool): Construct the decoder module (used with pipeline parallelism). Defaults to True. When we use pipelining, the decoder + will live on only a subset of the pipeline stages (specifically, every stage after the first one). + parallel_output (bool): Enable parallel model output. + + Returns: + model: A multimodal model. 
+ """ + args = get_args() + assert args.encoder_pipeline_model_parallel_size <= 1, "LLaVA does not support pp>1 for encoder on it's own pipeline rank" + + use_te = args.use_te + + print_rank_0('building a multimodal model ...') + + num_image_embeddings = get_num_image_embeddings( + args.img_h, + args.img_w, + args.patch_dim, + args.vision_model_type, + args.disable_vision_class_token, + 1, + args.pixel_shuffle, + args.use_tile_tags, + ) + old_seq_length = args.seq_length + args.seq_length = args.encoder_seq_length = num_image_embeddings + if torch.distributed.get_rank() == 0 and old_seq_length != args.seq_length: + warnings.warn( + f"Changed seq_length and encoder_seq_length (vision model sequence length) from {old_seq_length} to num_image_tokens ({num_image_embeddings})" + ) + + max_num_image_embeddings = (args.max_num_tiles + int(args.use_thumbnail)) * num_image_embeddings + + assert ( + args.decoder_seq_length is not None + ), "Please provide --decoder-seq-length to set the language model sequence length" + assert ( + args.decoder_seq_length > max_num_image_embeddings + ), "Language model sequence length must be greater than the maximum number of image embeddings" + if args.decoder_seq_length > args.max_position_embeddings: + args.max_position_embeddings = args.decoder_seq_length + warnings.warn( + f"Expanded max_position_embeddings to {args.max_position_embeddings} to accommodate the maximum language model sequence length" + ) + + base_config = core_transformer_config_from_args(get_args()) + base_config.language_model_type = args.language_model_type + base_config.vision_model_type = args.vision_model_type + base_config.calculate_per_token_loss = True + + language_config = deepcopy(base_config) + language_config = get_language_model_config(language_config) + + if use_te: + # Padding mask needed for SP/CP. + padding = args.context_parallel_size > 1 and args.sequence_parallel + language_transformer_layer_spec = get_layer_spec_te( + is_vit=False, padding=padding + ) # TENorm detects LayerNorm/RMS automatically. 
+ else: + language_transformer_layer_spec = get_layer_spec( + is_vit=False, normalization=language_config.normalization + ) + + vision_model_type = args.vision_model_type + vision_config = deepcopy(base_config) + vision_config = get_vision_model_config( + vision_config, apply_query_key_layer_scaling=args.apply_query_key_layer_scaling + ) + if vision_model_type.startswith("huggingface"): + assert args.encoder_tensor_model_parallel_size < 2, "Huggingface vision encoders do not support --encoder-tensor-model-parallel-size > 1" + assert args.encoder_pipeline_model_parallel_size == 0, "Huggingface vision encoders do not support --encoder-pipeline-model-parallel-size > 0" + assert not args.sequence_parallel, "Huggingface models do not support --sequence-parallel" + assert args.context_parallel_size < 2, "Huggingface models do not support --context-parallel-size > 1" + assert args.vision_huggingface_model_name_or_path is not None, "Providing --vision-huggingface-model-name-or-path is necessary when using huggingface vision model" + + vision_config.huggingface_model_name_or_path = args.vision_huggingface_model_name_or_path + + from transformers import AutoConfig + huggingface_config = AutoConfig.from_pretrained(vision_config.huggingface_model_name_or_path) + vision_config.hidden_size = huggingface_config.hidden_size + + vision_model_type = args.vision_model_type + if vision_model_type in ["clip", "siglip", "radio"]: + if use_te: + vision_transformer_layer_spec = get_layer_spec_te( + is_vit=True + ) # TENorm detects LayerNorm/RMS automatically. + else: + vision_transformer_layer_spec = get_layer_spec( + is_vit=True, normalization=vision_config.normalization + ) + elif vision_model_type == "internvit": + from nvlm.internvit import get_internvit_layer_spec + vision_transformer_layer_spec = get_internvit_layer_spec(use_te=use_te) + elif vision_model_type.startswith("huggingface"): + vision_transformer_layer_spec = None + else: + raise RuntimeError("unsupported vision model type", vision_model_type) + + vision_projection_config = deepcopy(base_config) + + if base_config.language_model_type.startswith("huggingface"): + assert args.tensor_model_parallel_size == 1, "Huggingface models do not support --tensor-model-parallel-size > 1" + assert args.pipeline_model_parallel_size < 2, "Huggingface models do not support --pipeline-model-parallel-size > 1" + assert not args.sequence_parallel, "Huggingface models do not support --sequence-parallel" + assert args.context_parallel_size < 2, "Huggingface models do not support --context-parallel-size > 1" + assert args.language_huggingface_model_name_or_path is not None, "Providing --language-huggingface-model-name-or-path is necessary when using huggingface language model" + + language_config.huggingface_model_name_or_path = args.language_huggingface_model_name_or_path + # Pass to vision projection config so can choose the correct ffn hidden size + vision_projection_config.huggingface_model_name_or_path = args.language_huggingface_model_name_or_path + + vision_projection_config = get_vision_projection_config( + vision_projection_config, language_config.hidden_size + ) + + # --encoder-pipeline-model-parallel-size 1 will enable a separate pipeline stage for the vision model. + if args.encoder_pipeline_model_parallel_size > 0: + assert ( + args.encoder_pipeline_model_parallel_size == 1 + ), "vision model and projection can only live on 1 pipeline stage." 
+ + if args.encoder_tensor_model_parallel_size > 0: + vision_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size + vision_projection_config.tensor_model_parallel_size = ( + args.encoder_tensor_model_parallel_size + ) + + # Make sure vision model pipeline parallel size is not inherited from the language model pipeline parallel size. + # 0 is not a valid for the config value, hence max(1, ). + vision_config.pipeline_model_parallel_size = max(1, args.encoder_pipeline_model_parallel_size) + vision_projection_config.pipeline_model_parallel_size = vision_config.pipeline_model_parallel_size + + # Make sure the vision model does not inherit first and last pipeline num layers from the language model. + vision_config.first_pipeline_num_layers = vision_config.last_pipeline_num_layers = None + + if vision_projection_config.normalization: + vision_projection_layer_spec = get_norm_mlp_module_spec_te().submodules + else: + vision_projection_layer_spec = get_mlp_module_spec(use_te=use_te).submodules + + # Toggle --recompute* for the vision and language model separately. + if args.recompute_vision: + if vision_config.recompute_method is not None and vision_config.recompute_granularity is not None: + vision_config.recompute_num_layers = vision_config.num_layers + else: + vision_config.recompute_granularity = None + vision_config.recompute_method = None + vision_config.recompute_num_layers = None + + vision_projection_config.recompute_granularity = None + vision_projection_config.recompute_method = None + vision_projection_config.recompute_num_layers = None + + # TODO: Vision model and projection do not use SP/CP yet. + vision_config.sequence_parallel = False + vision_config.context_parallel_size = 1 + vision_config.tp_comm_overlap = False + + vision_projection_config.sequence_parallel = False + vision_projection_config.context_parallel_size = 1 + vision_projection_config.tp_comm_overlap = False + + tokenizer = get_tokenizer() + image_token_index = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + assert image_token_index is not None, f"IMAGE_TOKEN={IMAGE_TOKEN} needs to be added using the --special-tokens arg." 
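+    # LLaVAModel uses this index to locate the image placeholder tokens in the input ids and
+    # splice the projected vision embeddings in at those positions.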
+ + tile_tags = _get_tile_tags(args, tokenizer) + + model = LLaVAModel( + language_transformer_config=language_config, + language_transformer_layer_spec=language_transformer_layer_spec, + language_vocab_size=args.padded_vocab_size, + language_max_sequence_length=args.decoder_seq_length, + vision_transformer_config=vision_config, + vision_transformer_layer_spec=vision_transformer_layer_spec, + drop_vision_class_token=args.disable_vision_class_token, + vision_projection_config=vision_projection_config, + vision_projection_layer_spec=vision_projection_layer_spec, + vision_projection_type="mlp", + allow_missing_vision_projection_checkpoint=args.allow_missing_vision_projection_checkpoint, + parallel_output=parallel_output, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + language_position_embedding_type=args.position_embedding_type, + language_rotary_percent=args.rotary_percent, + pre_process=pre_process, + post_process=post_process, + add_encoder=add_encoder, + add_decoder=add_decoder, + img_h=args.img_h, + img_w=args.img_w, + patch_dim=args.patch_dim, + language_rotary_base=args.rotary_base, + language_rope_scaling=args.use_rope_scaling, + image_token_index=image_token_index, + pixel_shuffle=args.pixel_shuffle, + tile_tags=tile_tags, + ) + + model.freeze( + freeze_language_model=args.freeze_LM, + freeze_vision_model=args.freeze_ViT, + freeze_vision_projection=False, + ) + + return model + + +def _get_tile_tags(args, tokenizer): + """Tile tags are used in NVLM to surround image tiles with text tags.""" + if not args.use_tile_tags: + return None + + # We expect the tokenized length of the tags is same. + thumbnail_tag_text = "" + if args.tokenizer_prompt_format == "nvlm-yi-34b": + thumbnail_tag_text = "" + + assert args.max_num_tiles <= 6, "Up to 6 tile tags used" + tile_tags_text = [f"" for i in range(1, args.max_num_tiles + 1)] + [thumbnail_tag_text] + + start_idx = 0 + if tokenizer._prompt_config.has_bos: + start_idx = 1 + + # Convert to tokens [num_tiles, tile_seq_len]. + tile_tags = [tokenizer.tokenize(t)[start_idx:] for t in tile_tags_text] + + return tile_tags diff --git a/examples/multimodal/model_converter/clip_converter.py b/examples/multimodal/model_converter/clip_converter.py new file mode 100644 index 0000000000..696c810890 --- /dev/null +++ b/examples/multimodal/model_converter/clip_converter.py @@ -0,0 +1,163 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import argparse +import os + +import torch + +import clip + + +def convert(download_root, output_path, tensor_parallel_size, use_te): + device = "cuda" + + model, _ = clip.load("ViT-L/14@336px", device=device, download_root=download_root) + + state_dict = model.state_dict() + new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)] + + # Indices from mapping pytorch multihead attention to megatron. + kv_channels = 64 + hidden_dim = 1024 + num_heads = 16 + indices = [] + for i in range(num_heads): + lb = i * kv_channels + ub = (i + 1) * kv_channels + indices.append(torch.arange(lb, ub, dtype=torch.int)) + indices.append(torch.arange(hidden_dim + lb, hidden_dim + ub, dtype=torch.int)) + indices.append(torch.arange(2 * hidden_dim + lb, 2 * hidden_dim + ub, dtype=torch.int)) + + indices = torch.cat(indices) + + for name, tensor in state_dict.items(): + # Skip text model. + if "visual" not in name: + continue + + # Skip final layers not used in our model. 
+ if name == "visual.proj" or "ln_post" in name: + continue + + # Map parameter names to ones used in megatron. + new_name = "" + new_tensor = tensor + if new_tensor.dtype == torch.float16: + new_tensor = new_tensor.to(torch.float32) + + # This is used for chunking some tensors to target tensor parallel size. + chunk_dim = None + + if "class_embedding" in name: + new_name = "class_token" + # Our model uses class token that is expanded to input dimensions already. + new_tensor = new_tensor.expand(1, 1, -1) + elif "positional_embedding" in name: + new_name = "position_embeddings.weight" + elif "conv1" in name: + new_name = "conv1.weight" + elif "ln_pre.weight" in name: + new_name = "ln_pre.weight" + elif "ln_pre.bias" in name: + new_name = "ln_pre.bias" + elif "transformer.resblocks" in name: + layer_idx = name.split(".")[3] + base = f"decoder.layers.{layer_idx}" + + if "attn.in_proj_weight" in name: + new_name = f"{base}.self_attention.linear_qkv.weight" + new_tensor = new_tensor[indices] + chunk_dim = 0 + elif "attn.in_proj_bias" in name: + new_name = f"{base}.self_attention.linear_qkv.bias" + new_tensor = new_tensor[indices] + chunk_dim = 0 + elif "attn.out_proj.weight" in name: + new_name = f"{base}.self_attention.linear_proj.weight" + chunk_dim = 1 + elif "attn.out_proj.bias" in name: + new_name = f"{base}.self_attention.linear_proj.bias" + elif "ln_1.weight" in name: + new_name = f"{base}.input_layernorm.weight" + if use_te: + new_name = f"{base}.self_attention.linear_qkv.layer_norm_weight" + elif "ln_1.bias" in name: + new_name = f"{base}.input_layernorm.bias" + if use_te: + new_name = f"{base}.self_attention.linear_qkv.layer_norm_bias" + elif "mlp.c_fc.weight" in name: + new_name = f"{base}.mlp.linear_fc1.weight" + chunk_dim = 0 + elif "mlp.c_fc.bias" in name: + new_name = f"{base}.mlp.linear_fc1.bias" + chunk_dim = 0 + elif "mlp.c_proj.weight" in name: + new_name = f"{base}.mlp.linear_fc2.weight" + chunk_dim = 1 + elif "mlp.c_proj.bias" in name: + new_name = f"{base}.mlp.linear_fc2.bias" + elif "ln_2.weight" in name: + new_name = f"{base}.pre_mlp_layernorm.weight" + if use_te: + new_name = f"{base}.mlp.linear_fc1.layer_norm_weight" + elif "ln_2.bias" in name: + new_name = f"{base}.pre_mlp_layernorm.bias" + if use_te: + new_name = f"{base}.mlp.linear_fc1.layer_norm_bias" + + assert new_name != "", f"unexpected layer name {name}" + + if chunk_dim is None: + new_tensors = [new_tensor for _ in range(tensor_parallel_size)] + else: + new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim) + + for i in range(tensor_parallel_size): + # chunk() creates a view of a bigger tensor. clone() is used here to avoid excessive storage. + new_state_dicts[i]["model"][new_name] = new_tensors[i].clone() + + # TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility. + extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2") + is_extra_state_layer = any([l in new_name for l in extra_state_layers]) + if use_te and is_extra_state_layer: + layer = new_name.split(".")[-2] + if layer in extra_state_layers: + extra_state_name = ( + new_name[: new_name.rfind(".") + 1] + "_extra_state" + ) # Replace the weight name. 
+ new_state_dicts[i]["model"][extra_state_name] = None + + for i in range(tensor_parallel_size): + output_dir_tp = os.path.join(output_path, "iter_0000001", f"mp_rank_0{i}") + os.makedirs(output_dir_tp) + output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt") + torch.save(new_state_dicts[i], output_path_tp) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=""" +Convert OpenAI CLIP VIT weights to megatron format. + + +Example usage: +python clip_converter.py --download-root /some/download/folder --output /some/output/folder --tensor-parallel-size 4 +""", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "--download-root", type=str, required=True, help="Download folder for OpenAI CLIP weights" + ) + parser.add_argument( + "--output", type=str, required=True, help="output directory for megatron state dict file(s)" + ) + parser.add_argument( + "--tensor-parallel-size", type=int, default=1, help="model tensor parallel size" + ) + parser.add_argument("--use-te", action="store_true", help="Use Transformer Engine") + + args = parser.parse_args() + + convert(args.download_root, args.output, args.tensor_parallel_size, args.use_te) + + print("done.") diff --git a/examples/multimodal/model_converter/internvit_converter.py b/examples/multimodal/model_converter/internvit_converter.py new file mode 100755 index 0000000000..48404c2084 --- /dev/null +++ b/examples/multimodal/model_converter/internvit_converter.py @@ -0,0 +1,162 @@ +import argparse +import os + +import torch +from transformers import AutoModel + + +def convert(model_name, output_path, tensor_parallel_size, use_te): + """Convert InternViT HF checkpoint to mcore.""" + hf_model = AutoModel.from_pretrained( + model_name, + trust_remote_code=True + ) + + hf_state_dict = hf_model.state_dict() + new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)] + + hidden_size = 3200 + num_heads = 25 + dim = 128 + + order = torch.ones(3 * hidden_size).long() + + for j in range(num_heads): + for i in range(dim): + order[i + dim*3*j] = j*dim+i + order[dim + i + dim*3*j] = j*dim+i+num_heads*dim + order[dim*2 + i + dim*3*j] = j*dim+i+num_heads*dim*2 + + for name, tensor in hf_state_dict.items(): + # Map parameter names to ones used in megatron. + new_name = "" + new_tensor = tensor + + # This is used for chunking some tensors to target tensor parallel size. + chunk_dim = None + + if "embeddings.class_embedding" in name: + new_name = "class_token" + elif "embeddings.patch_embedding.weight" in name: + new_name = "conv1.weight" + elif "embeddings.patch_embedding.bias" in name: + new_name = "conv1.bias" + elif "embeddings.position_embedding" in name: + new_name = "position_embeddings.weight" + new_tensor = new_tensor.squeeze(0) + elif "encoder.layers" in name: + layer_idx = name.split(".")[2] + + base = f"decoder.layers.{layer_idx}" + + head_dim = 128 + + if tensor_parallel_size == 1: + num_padded_heads = 25 + elif tensor_parallel_size == 8: + # Note: 25 is not divisible by 8 and we don't currently support uneven heads split with tensor parallelism. + # So we pad with dummy all-zero heads. Please use a nice even number of attention heads in your model. 
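+                # For example: 25 heads x 128 channels = 3200 rows per Q/K/V; padding to
+                # 32 heads gives 4096 rows, which split evenly into 4 heads (512 rows) per rank.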
+ num_padded_heads = 32 + else: + raise NotImplementedError("invalid tensor parallel size value:", tensor_parallel_size) + + if "ls1" in name: + new_name = f"{base}.ls1" + elif "ls2" in name: + new_name = f"{base}.ls2" + elif "attn.qkv.weight" in name: + new_name = f"{base}.self_attention.linear_qkv.weight" + num_tensors = 3 + padded_dim = head_dim * num_padded_heads * num_tensors + padded_tensor = torch.zeros((padded_dim, new_tensor.shape[-1]), dtype=new_tensor.dtype, device=new_tensor.device) + padded_tensor[:new_tensor.shape[0], :] = new_tensor[order] + new_tensor = padded_tensor + chunk_dim = 0 + elif "attn.q_norm.weight" in name: + new_name = f"{base}.self_attention.q_layernorm.weight" + num_tensors = 1 + padded_dim = head_dim * num_padded_heads * num_tensors + padded_tensor = torch.zeros(padded_dim, dtype=new_tensor.dtype, device=new_tensor.device) + padded_tensor[:new_tensor.shape[0]] = new_tensor + new_tensor = padded_tensor + chunk_dim = 0 + elif "attn.k_norm.weight" in name: + new_name = f"{base}.self_attention.k_layernorm.weight" + num_tensors = 1 + padded_dim = head_dim * num_padded_heads * num_tensors + padded_tensor = torch.zeros(padded_dim, dtype=new_tensor.dtype, device=new_tensor.device) + padded_tensor[:new_tensor.shape[0]] = new_tensor + new_tensor = padded_tensor + chunk_dim = 0 + elif "attn.proj.weight" in name: + new_name = f"{base}.self_attention.linear_proj.weight" + num_tensors = 1 + padded_dim = head_dim * num_padded_heads * num_tensors + padded_tensor = torch.zeros((new_tensor.shape[0], padded_dim), dtype=new_tensor.dtype, device=new_tensor.device) + padded_tensor[:, :new_tensor.shape[-1]] = new_tensor + new_tensor = padded_tensor + chunk_dim = 1 + elif "attn.proj.bias" in name: + new_name = f"{base}.self_attention.linear_proj.bias" + elif "mlp.fc1.weight" in name: + new_name = f"{base}.mlp.linear_fc1.weight" + chunk_dim = 0 + elif "mlp.fc1.bias" in name: + new_name = f"{base}.mlp.linear_fc1.bias" + chunk_dim = 0 + elif "mlp.fc2.weight" in name: + new_name = f"{base}.mlp.linear_fc2.weight" + chunk_dim = 1 + elif "mlp.fc2.bias" in name: + new_name = f"{base}.mlp.linear_fc2.bias" + elif "norm1" in name: + new_name = f"{base}.input_layernorm.weight" + elif "norm2" in name: + new_name = f"{base}.pre_mlp_layernorm.weight" + else: + raise RuntimeError("unexpected transformer layer name", name) + else: + raise RuntimeError("unexpected layer name", name) + + assert new_name != "", f"unexpected layer name {name}" + + # TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility. + extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2") + is_extra_state_layer = any([l in new_name for l in extra_state_layers]) + if use_te and is_extra_state_layer: + layer = new_name.split(".")[-2] + if layer in extra_state_layers: + extra_state_name = ( + new_name[: new_name.rfind(".") + 1] + "_extra_state" + ) # Replace the weight name. 
+ for i in range(tensor_parallel_size): + new_state_dicts[i]["model"][extra_state_name] = None + + if chunk_dim is None: + new_tensors = [new_tensor for _ in range(tensor_parallel_size)] + else: + new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim) + + for i in range(tensor_parallel_size): + new_state_dicts[i]["model"][new_name] = new_tensors[i].clone() + + for i in range(tensor_parallel_size): + output_dir_tp = os.path.join(output_path, f"iter_0000001/mp_rank_0{i}") + os.makedirs(output_dir_tp, exist_ok=True) + output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt") + torch.save(new_state_dicts[i], output_path_tp) + print("saved file", output_path_tp) + + print("done") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="InternVIT HuggingFace to Mcore converter") + parser.add_argument("--model-name", type=str, default="OpenGVLab/InternViT-6B-448px-V1-5", help="Model name in HuggingFace") + parser.add_argument("--output-dir", type=str, required=True, help="Output directory for the mcore model.") + parser.add_argument("--use-te", action="store_true", default=True) + parser.add_argument("--tensor-parallel-size", type=int, required=True) + + args = parser.parse_args() + + convert(args.model_name, args.output_dir, args.tensor_parallel_size, args.use_te) diff --git a/examples/multimodal/model_converter/radio_converter.py b/examples/multimodal/model_converter/radio_converter.py new file mode 100644 index 0000000000..05750a66eb --- /dev/null +++ b/examples/multimodal/model_converter/radio_converter.py @@ -0,0 +1,152 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import argparse +import os + +import torch + +def convert(output_path, tensor_parallel_size, use_te, version): + device = "cuda" + + model = torch.hub.load('NVlabs/RADIO', 'radio_model', version=version, progress=True) + + state_dict = model.state_dict() + new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)] + + # Indices from mapping pytorch multihead attention to megatron. + kv_channels = 80 + hidden_dim = 1280 + num_heads = 16 + indices = [] + for i in range(num_heads): + lb = i * kv_channels + ub = (i + 1) * kv_channels + indices.append(torch.arange(lb, ub, dtype=torch.int)) + indices.append(torch.arange(hidden_dim + lb, hidden_dim + ub, dtype=torch.int)) + indices.append(torch.arange(2 * hidden_dim + lb, 2 * hidden_dim + ub, dtype=torch.int)) + + indices = torch.cat(indices) + + for name, tensor in state_dict.items(): + # Map parameter names to ones used in megatron. + new_name = "" + new_tensor = tensor + if new_tensor.dtype == torch.float16: + new_tensor = new_tensor.to(torch.float32) + + # This is used for chunking some tensors to target tensor parallel size. 
+ chunk_dim = None + + if "summary_idxs" in name: + continue + elif "patch_generator" in name: + if "embedder" in name: + new_name = "embedder.weight" + chunk_dim = 0 + elif "cls_token" in name: + new_name = "class_token" + elif "pos_embed" in name: + new_name = "position_embeddings" + elif "input_conditioner" in name: + continue + elif "blocks" in name: + layer_idx = name.split(".")[2] + base = f"decoder.layers.{layer_idx}" + + if "attn.qkv.weight" in name: + new_name = f"{base}.self_attention.linear_qkv.weight" + new_tensor = new_tensor[indices] + chunk_dim = 0 + elif "attn.qkv.bias" in name: + new_name = f"{base}.self_attention.linear_qkv.bias" + new_tensor = new_tensor[indices] + chunk_dim = 0 + elif "attn.proj.weight" in name: + new_name = f"{base}.self_attention.linear_proj.weight" + chunk_dim = 1 + elif "attn.proj.bias" in name: + new_name = f"{base}.self_attention.linear_proj.bias" + elif "norm1.weight" in name: + new_name = f"{base}.input_layernorm.weight" + if use_te: + new_name = f"{base}.self_attention.linear_qkv.layer_norm_weight" + elif "norm1.bias" in name: + new_name = f"{base}.input_layernorm.bias" + if use_te: + new_name = f"{base}.self_attention.linear_qkv.layer_norm_bias" + elif "mlp.fc1.weight" in name: + new_name = f"{base}.mlp.linear_fc1.weight" + chunk_dim = 0 + elif "mlp.fc1.bias" in name: + new_name = f"{base}.mlp.linear_fc1.bias" + chunk_dim = 0 + elif "mlp.fc2.weight" in name: + new_name = f"{base}.mlp.linear_fc2.weight" + chunk_dim = 1 + elif "mlp.fc2.bias" in name: + new_name = f"{base}.mlp.linear_fc2.bias" + elif "norm2.weight" in name: + new_name = f"{base}.pre_mlp_layernorm.weight" + if use_te: + new_name = f"{base}.mlp.linear_fc1.layer_norm_weight" + elif "norm2.bias" in name: + new_name = f"{base}.pre_mlp_layernorm.bias" + if use_te: + new_name = f"{base}.mlp.linear_fc1.layer_norm_bias" + + assert new_name != "", f"unexpected layer name {name}" + + if chunk_dim is None: + new_tensors = [new_tensor for _ in range(tensor_parallel_size)] + else: + new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim) + + for i in range(tensor_parallel_size): + # chunk() creates a view of a bigger tensor. clone() is used here to avoid excessive storage. + new_state_dicts[i]["model"][new_name] = new_tensors[i].clone() + + # TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility. + extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2") + is_extra_state_layer = any([l in new_name for l in extra_state_layers]) + if use_te and is_extra_state_layer: + layer = new_name.split(".")[-2] + if layer in extra_state_layers: + extra_state_name = ( + new_name[: new_name.rfind(".") + 1] + "_extra_state" + ) # Replace the weight name. + new_state_dicts[i]["model"][extra_state_name] = None + + for i in range(tensor_parallel_size): + output_dir_tp = os.path.join(output_path, "iter_0000001", f"mp_rank_0{i}") + os.makedirs(output_dir_tp) + output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt") + torch.save(new_state_dicts[i], output_path_tp) + with open(os.path.join(output_path, "latest_checkpointed_iteration.txt"), "w") as f: + f.write("1") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=""" +Convert RADIO weights to megatron format. 
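+The source weights are downloaded via torch.hub from the NVlabs/RADIO repository.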
+ + +Example usage: +python radio_converter.py --output /some/output/folder --tensor-parallel-size 4 +""", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "--output", type=str, required=True, help="output directory for megatron state dict file(s)" + ) + parser.add_argument( + "--tensor-parallel-size", type=int, default=1, help="model tensor parallel size" + ) + parser.add_argument("--use-te", action="store_true", help="Use Transformer Engine") + parser.add_argument("--version", type=str, default="radio_v2.5-h", help="Version of radio to load for conversion") + + args = parser.parse_args() + + convert(args.output, args.tensor_parallel_size, args.use_te, args.version) + + print("done.") diff --git a/examples/multimodal/model_converter/siglip_converter.py b/examples/multimodal/model_converter/siglip_converter.py new file mode 100644 index 0000000000..666cda15eb --- /dev/null +++ b/examples/multimodal/model_converter/siglip_converter.py @@ -0,0 +1,154 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import argparse +import os +from transformers import PaliGemmaForConditionalGeneration +import torch + + +def convert(output_path, tensor_parallel_size, use_te): + device = "cuda" + + model_id = "google/paligemma-3b-pt-448" + model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval() + + model = model.to(device) + + print(model.config) + for name, tensor in model.state_dict().items(): + if "vision_model" not in name: + continue + shape_str = "(" + ", ".join([str(x) for x in tensor.shape]) + ")" + print(f"{name:<75} {shape_str:>20}") + + state_dict = model.state_dict() + new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)] + + def add_chunck_tensor(new_tensor, new_name, chunk_dim=None): + if chunk_dim is None: + new_tensors = [new_tensor for _ in range(tensor_parallel_size)] + else: + new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim) + + for i in range(tensor_parallel_size): + # chunk() creates a view of a bigger tensor. clone() is used here to avoid excessive storage. + new_state_dicts[i]["model"][new_name] = new_tensors[i].clone() + + # TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility. + extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2") + is_extra_state_layer = any([l in new_name for l in extra_state_layers]) + if use_te and is_extra_state_layer: + layer = new_name.split(".")[-2] + if layer in extra_state_layers: + extra_state_name = ( + new_name[: new_name.rfind(".") + 1] + "_extra_state" + ) # Replace the weight name. 
+ new_state_dicts[i]["model"][extra_state_name] = None + + for name, tensor in state_dict.items(): + if tensor.dtype == torch.float16: + state_dict[name] = tensor.to(torch.float32) + + add_chunck_tensor( + state_dict["vision_tower.vision_model.embeddings.position_embedding.weight"], + "position_embeddings.weight") + add_chunck_tensor( + state_dict["vision_tower.vision_model.embeddings.patch_embedding.weight"], + "conv1.weight") + add_chunck_tensor( + state_dict["vision_tower.vision_model.embeddings.patch_embedding.bias"], + "conv1.bias") + + head_dim = 72 + num_head = 16 + for layer_idx in range(27): + origin_base = f"vision_tower.vision_model.encoder.layers.{layer_idx}" + target_base = f"decoder.layers.{layer_idx}" + + for param_type in ["weight", "bias"]: + # QKV + q_proj_params = state_dict[f"{origin_base}.self_attn.q_proj.{param_type}"] + k_proj_params = state_dict[f"{origin_base}.self_attn.k_proj.{param_type}"] + v_proj_params = state_dict[f"{origin_base}.self_attn.v_proj.{param_type}"] + # Do some tensor manipulation because megatron expect one tensor + # projection for the QKV in the order + # [(Q1, K1, V1), (Q2, K2, V2), ...] where Qi is the query of the + # i-th head with dimension num_head. + new_tensor = torch.concatenate([ + q_proj_params.view(num_head, head_dim, -1), + k_proj_params.view(num_head, head_dim, -1), + v_proj_params.view(num_head, head_dim, -1)], axis=1).view( + 3*head_dim*num_head, -1) + if param_type == "bias": + new_tensor = new_tensor[:, 0] + new_name = f"{target_base}.self_attention.linear_qkv.{param_type}" + add_chunck_tensor(new_tensor, new_name, chunk_dim=0) + # linear_proj + add_chunck_tensor( + state_dict[f"{origin_base}.self_attn.out_proj.{param_type}"], + f"{target_base}.self_attention.linear_proj.{param_type}", + chunk_dim=1 if param_type == "weight" else None) + # layer_norm + new_name = f"{target_base}.input_layernorm.{param_type}" + if use_te: + new_name = f"{target_base}.self_attention.linear_qkv.layer_norm_{param_type}" + add_chunck_tensor( + state_dict[f"{origin_base}.layer_norm1.{param_type}"], + new_name) + # FC 1 + add_chunck_tensor( + state_dict[f"{origin_base}.mlp.fc1.{param_type}"], + f"{target_base}.mlp.linear_fc1.{param_type}", + chunk_dim=0) + # FC 2 + add_chunck_tensor( + state_dict[f"{origin_base}.mlp.fc2.{param_type}"], + f"{target_base}.mlp.linear_fc2.{param_type}", + chunk_dim=1 if param_type=="weight" else None) + # layer_norm + new_name = f"{target_base}.pre_mlp_layernorm.{param_type}" + if use_te: + new_name = f"{target_base}.mlp.linear_fc1.layer_norm_{param_type}" + add_chunck_tensor( + state_dict[f"{origin_base}.layer_norm2.{param_type}"], + new_name) + + add_chunck_tensor( + state_dict["vision_tower.vision_model.post_layernorm.weight"], + "ln_post.weight") + add_chunck_tensor( + state_dict["vision_tower.vision_model.post_layernorm.bias"], + "ln_post.bias") + + for i in range(tensor_parallel_size): + output_dir_tp = os.path.join(output_path, "iter_0000001", f"mp_rank_0{i}") + os.makedirs(output_dir_tp) + output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt") + torch.save(new_state_dicts[i], output_path_tp) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=""" +Convert SigLIP weights to megatron format. 
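+The weights are taken from the vision tower of the google/paligemma-3b-pt-448 checkpoint.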
+ + +Example usage: +python siglip_converter.py --tensor-parallel-size 4 --output google_paligemma_3b_pt_44_mcore_tp_4 --use-te + +examples/multimodal/combine_mistral_clip.sh Mistral-7B-Instruct-v0.3-mcore-tp4 google_paligemma_3b_pt_44_mcore_tp_4 mistral_7b_instruct_v0p3_google_paligemma_3b_pt_44_mcore_tp_4 +""", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--output", type=str, required=True, help="output directory for megatron state dict file(s)" + ) + parser.add_argument( + "--tensor-parallel-size", type=int, default=1, help="model tensor parallel size" + ) + parser.add_argument("--use-te", action="store_true", help="Use Transformer Engine") + + args = parser.parse_args() + + convert(args.output, args.tensor_parallel_size, args.use_te) + + print("done.") diff --git a/examples/multimodal/model_converter/vision_model_tester.py b/examples/multimodal/model_converter/vision_model_tester.py new file mode 100644 index 0000000000..ef36dd5f9e --- /dev/null +++ b/examples/multimodal/model_converter/vision_model_tester.py @@ -0,0 +1,121 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import argparse +import os +import sys + +# Add megatron and the multimodal example to the path. +sys.path.append( + os.path.abspath( + os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir, os.path.pardir) + ) +) +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) + +import torch +from transformers import AutoModel + +from examples.multimodal.model import model_provider +from examples.multimodal.multimodal_args import add_multimodal_extra_args +from megatron.training import get_model +from megatron.training.checkpointing import load_checkpoint +from megatron.training.initialize import initialize_megatron + + +def run_mcore_vision(model_path): + """Run mcore vision model.""" + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + + # Megatron has some mandatory flags. + sys.argv = [ + "ignore_me.py", + "--micro-batch-size=1", + "--num-layers=2", + "--vision-model-type=internvit", + "--language-model-type=mistral_7b", + "--tokenizer-prompt-format=mistral", + "--tokenizer-type=MultimodalTokenizer", + "--tokenizer-model=mistralai/Mistral-7B-Instruct-v0.3", + "--vocab-size=1024", + "--hidden-size=64", + "--num-attention-heads=8", + "--seq-length=1024", + "--decoder-seq-length=2048", + "--max-position-embeddings=2048", + "--bf16", + "--img-h=448", + "--img-w=448", + "--patch-dim=14", + "--tensor-model-parallel-size=8", + "--use-te", + f"--pretrained-checkpoint={model_path}", + ] + + initialize_megatron(extra_args_provider=add_multimodal_extra_args) + + def wrapped_model_provider(pre_process, post_process): + return model_provider(pre_process, post_process, parallel_output=False) + + # Set up model and load checkpoint. 
+ model = get_model(wrapped_model_provider, wrap_with_ddp=False) + + vision_model = model[0].module.vision_model + + load_checkpoint([vision_model], None, None) + + vision_model.eval() + + images = torch.ones((1, 3, 448, 448), dtype=torch.bfloat16, device="cuda") + + output = vision_model(images) + + return output + + +def run_hf_vision(model_name): + """Run HF vision model.""" + model = ( + AutoModel.from_pretrained(model_name, torch_dtype=torch.bfloat16, trust_remote_code=True) + .cuda() + .eval() + ) + + images = torch.ones((1, 3, 448, 448), dtype=torch.bfloat16, device="cuda") + + outputs = model(images, return_dict=True) + + return outputs + + +def main(mcore_model, hf_model): + """Compare vision model outputs between mcore and HF given the same fixed input.""" + mcore = run_mcore_vision(mcore_model) + + if torch.distributed.get_rank() == 0: + hf = run_hf_vision(hf_model) + hf = hf["last_hidden_state"] + + # Compare logits. Due to different attention implementations and other details, + # there will be numerical differences. + diff = (mcore - hf).abs() + mean_diff = diff.mean().item() + max_diff = diff.max().item() + print(f"mean diff {mean_diff}, max diff {max_diff}") + assert mean_diff < 0.1, "mean output difference is greater than expected" + assert max_diff < 50, "max output difference is greater than expected" + + print("lgtm") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Check mcore vision model output vs. HF numerically.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--mcore-model", type=str, required=True, help="directory for mcore model weights" + ) + parser.add_argument("--hf-model", type=str, required=True, help="Model name in HF") + + args = parser.parse_args() + + main(args.mcore_model, args.hf_model) diff --git a/examples/multimodal/multimodal_args.py b/examples/multimodal/multimodal_args.py new file mode 100644 index 0000000000..bef6bb9a48 --- /dev/null +++ b/examples/multimodal/multimodal_args.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
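+"""Extra command-line arguments used by the multimodal example scripts."""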
+from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN + + +def add_multimodal_extra_args(parser): + """Extra arguments.""" + group = parser.add_argument_group(title='multimodal arguments') + group.add_argument('--dataset-config', type=str, default=None) + group.add_argument("--prompt-path", type=str, default=None) + group.add_argument('--freeze-LM', action='store_true', default=False) + group.add_argument('--freeze-ViT', action='store_true', default=False) + group.add_argument('--language-model-type', type=str, required=True) + group.add_argument('--language-huggingface-model-name-or-path', type=str) + group.add_argument('--vision-model-type', type=str, default="clip") + group.add_argument('--vision-huggingface-model-name-or-path', type=str) + group.add_argument("--disable-vision-class-token", action="store_true", default=False) + group.add_argument( + "--allow-missing-vision-projection-checkpoint", action="store_true", default=False + ) + group.add_argument("--use-te", action="store_true", default=False) + group.add_argument( + "--dataloader-save", type=str, default=None, help="Energon dataloader state save path" + ) + group.add_argument( + "--use-tiling", action="store_true", default=False, help="Use input image tiling" + ) + group.add_argument("--max-num-tiles", type=int, default=1, help="Maximum number of image tiles") + group.add_argument( + "--use-thumbnail", action="store_true", default=False, help="Add image thumbnail as a tile" + ) + group.add_argument( + "--dataloader-seq-length", + type=int, + help="Make dataloader to produce sequences of specific length.", + ) + group.add_argument( + "--num-frames", + type=int, + default=1, + help="Number of frames to regularly sample from the video as input to the model.", + ) + group.add_argument( + "--online-evaluation-config", type=str, help="Config file for online evaluation." + ) + group.add_argument( + "--special-tokens", + nargs="*", + default=[IMAGE_TOKEN], + help="Special tokens used in the multimodal model", + ) + group.add_argument( + "--tokenizer-prompt-format", + type=str, + choices=["mistral", "llama3", "llama3p1", "chatml", "nvlm-yi-34b", "qwen2p0", "qwen2p5"], + required=True, + help="Prompt format to use with the tokenizer.", + ) + group.add_argument("--pixel-shuffle", action="store_true", default=False) + group.add_argument( + "--image-tag-type", + type=str, + choices=["nvlm", "internvl", ""], + default="", # Default: Image tag not used. + help="Surround image tokens with tags.", + ) + group.add_argument("--use-tile-tags", action="store_true", default=False, help="Use tile tags") + group.add_argument( + "--packing-buffer-size", + type=int, + default=None, # Packing is disabled by default. + help="Enable sample packing by setting the buffer size to > 0", + ) + group.add_argument( + "--packing-seq-length", type=int, default=0, help="Packing sequence length. Must be > 0 if using packing." + ) + group.add_argument( + "--recompute-vision", action="store_true", default=False, help="Enable activation checkpointing in the vision model" + ) + group.add_argument( + "--use-loss-scaling", action="store_true", default=False, help="Scale loss based on conversation turn length (in tokens)." 
+ )
+ group.add_argument(
+ "--use-area-weighted-aspect-ratio", action="store_true", default=False,
+ help=(
+ "When --use-tiling is True, find the aspect ratio to use based on the original "
+ "image aspect ratio and the area covered by the tiles.")
+ )
+
+ return parser
diff --git a/examples/multimodal/nvlm/README.md b/examples/multimodal/nvlm/README.md
new file mode 100644
index 0000000000..bb576bb403
--- /dev/null
+++ b/examples/multimodal/nvlm/README.md
@@ -0,0 +1,107 @@
+NVLM
+====
+
+Please refer to the [NVLM paper](https://arxiv.org/pdf/2409.11402) for details.
+
+*NOTE: VLMs in Megatron are under active development and are expected to change.*
+
+# Checkpoints
+
+NVLM 1.0 model weights are publicly available in both HuggingFace and Megatron-Core formats.
+
+- NVLM-1.0-D 72B [HuggingFace version](https://huggingface.co/nvidia/NVLM-D-72B)
+- NVLM-1.0-D 72B [Megatron-Core version](https://huggingface.co/nvidia/NVLM-D-72B-mcore)
+
+# Setup
+
+## Docker image
+
+Please use `examples/multimodal/Dockerfile`.
+
+## Dataset preparation
+
+Please refer to Tables 4 and 6 in the [NVLM paper](https://arxiv.org/pdf/2409.11402) for the full list of pretraining and SFT datasets.
+Please refer to https://nvidia.github.io/Megatron-Energon/data_prep.html for preparing datasets in the Megatron Energon format.
+
+## Model conversion
+
+### Vision model
+
+NVLM 1.0 models use [OpenGVLab/InternViT-6B-448px-V1-5](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V1-5) from HuggingFace.
+Please download it and run the following command to convert it to Megatron format.
+```
+python examples/multimodal/model_converter/internvit_converter.py --output-dir --use-te --tensor-parallel-size 8
+```
+
+### 34B Language model
+
+NVLM 1.0 34B starts from [NousResearch/Nous-Hermes-2-Yi-34B](https://huggingface.co/NousResearch/Nous-Hermes-2-Yi-34B) from HuggingFace.
+Please download it and run the following command to convert it to Megatron format.
+```
+python tools/checkpoint/convert.py --bf16 --model-type GPT --loader llama_mistral --saver mcore --target-tensor-parallel-size 8 --checkpoint-type hf \
+ --load-dir --save-dir --tokenizer-model \
+ --saver-transformer-impl transformer_engine --model-size yi-34B --make-vocab-size-divisible-by 1
+```
+
+### 72B Language model
+
+NVLM 1.0 72B starts from [Qwen/Qwen2-72B-Instruct](https://huggingface.co/Qwen/Qwen2-72B-Instruct) from HuggingFace.
+Please download it and run the following command to convert it to Megatron format.
+```
+python tools/checkpoint/convert.py --bf16 --model-type GPT --loader llama_mistral --saver mcore --target-tensor-parallel-size 8 --checkpoint-type hf \
+ --load-dir --save-dir --tokenizer-model \
+ --saver-transformer-impl transformer_engine --model-size qwen2.5-72Bf
+```
+
+### Combined checkpoint
+
+Combine the vision model checkpoint from [InternViT](#vision-model) with the [34B](#34b-language-model) or [72B](#72b-language-model) language model by running:
+```
+examples/multimodal/combine_lm_vision_checkpoints.sh nvlm
+```
+
+# Training
+
+## 34B
+
+1. Pretraining: please run `examples/multimodal/nvlm/pretrain_yi_34b_internvit_6b.sh`. Please use the InternViT + 34B [combined checkpoint](#combined-checkpoint) and tokenizer from HuggingFace.
+2. SFT: please run `examples/multimodal/nvlm/sft_34b_internvit.sh` using the checkpoint from 1.
+
+## 72B
+
+1. Pretraining: please run `examples/multimodal/nvlm/pretrain_qwen20_72b_internvit_6b.sh`. Please use the InternViT + 72B [combined checkpoint](#combined-checkpoint) and tokenizer from HuggingFace.
+2. 
Convert the pretraining checkpoint from 1. to have pipeline parallel size = 4 for SFT. Please run +``` +examples/multimodal/nvlm/pp_checkpoint_converter.py --input \ +--input-pipeline-parallel 1 --output --output-pipeline-parallel 4 \ +--tensor-parallel 8 +``` +3. SFT: please run `examples/multimodal/nvlm/sft_qwen20_72b_internvit_6b.sh` using the checkpoint from 2. +4. To convert the checkpoint with pipeline parallel size = 4 back to 1 for evaluation, please run +``` +examples/multimodal/nvlm/pp_checkpoint_converter.py --input \ +--input-pipeline-parallel 4 --output --output-pipeline-parallel 1 \ +--tensor-parallel 8 +``` + +# Evaluation + +Run the text generation script. +- 34B +``` +examples/multimodal/nvlm/run_text_generation_yi_34b_internvit_6b.sh --input-image-path /path/to/input/images --output-path /some/output/directory \ + --model-path /path/to/model.pt --gt-path /path/to/groundtruth/file --task generation-task-name --use-tiling +``` +- 72B +``` +examples/multimodal/nvlm/run_text_generation_qwen20_72b_internvit_6b.sh --input-image-path /path/to/input/images --output-path /some/output/directory \ + --model-path /path/to/model.pt --gt-path /path/to/groundtruth/file --task generation-task-name --use-tiling +``` + +where `--task generation-task-name` is the name of the evaluation benchmark such as `captioning`, `MMMU` or `TextVQA`. + +Then, run one of the evaluation scripts from `examples/multimodal`. For example + +``` +python examples/multimodal/evaluate_mmmu.py --input-path /output/directory/from/generation +``` diff --git a/examples/multimodal/nvlm/internvit.py b/examples/multimodal/nvlm/internvit.py new file mode 100644 index 0000000000..f00d2dd5f3 --- /dev/null +++ b/examples/multimodal/nvlm/internvit.py @@ -0,0 +1,279 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +"""" +NOTE: NVLM uses InternViT with tensor parallel (TP) size = 8. +Since InternViT has 25 attention heads and Megatron currently requires the number of attention heads +to be divisible by the TP size, we add 7 dummy zero attention heads to have 32 attention heads. + +This workaround requires some changes to how we compute RMSNorm, Attention etc. + +Additionally, InternViT introduces some unique features like Layer Scaling. + +Those code changes are gathered here. 
+""" +from functools import partial + +import torch + +from megatron.core.utils import divide +from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TERowParallelLinear, +) +from megatron.core.parallel_state import ( + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules +from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint + + +class InternViTRMSNorm(MegatronModule): + + def __init__( + self, + config, + hidden_size: int, + eps: float = 1e-6, + sequence_parallel: bool = False, + compute_var: bool = False, + ): + """Custom RMSNorm for InternViT. + + Args: + config (TransformerConfig): Config. + hidden_size (int): Input hidden size. + eps (float): epsilon to use for the norm, default to 1e-6 + sequence_parallel (bool): Set to true if sequence parallelism is being used, + this marks the weights as needing to be allreduced. + compute_var (bool): Indicator to compute statistic manually. + """ + super().__init__(config=config) + self.config = config + self.eps = eps + self.weight = torch.nn.Parameter(torch.ones(hidden_size)) + self._compute_var = compute_var + + assert not sequence_parallel, "Sequence parallelism is not supported with InternViT." + + setattr(self.weight, 'sequence_parallel', sequence_parallel) + + def _norm(self, x, var): + if var is None: + var = x.pow(2).mean(-1, keepdim=True) + + return x * torch.rsqrt(var + self.eps) + + def forward(self, x): + """Run RMSNorm with an option to compute custom statistic.""" + var = None + if self._compute_var: + unpadded_hidden_size = self.config.hidden_size # 3200 + max_dim = x.shape[-1] # 128 + + x = x.reshape(x.size(0), x.size(1), -1) + var = self._gather_var(x.float().pow(2), max_dim) / unpadded_hidden_size + + output = self._norm(x.float(), var).type_as(x) + output = output * self.weight + + if self._compute_var: + output = output.reshape(output.size(0), output.size(1), -1, max_dim) + + return output + + def _gather_var(self, input_, max_dim): + """Compute statistic across the non-dummy heads.""" + world_size = get_tensor_model_parallel_world_size() + + # Size and dimension. + last_dim = input_.dim() - 1 + rank = get_tensor_model_parallel_rank() + + num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size) + valid_ranks = 24 // num_attention_heads_per_partition + + residual_heads = 25 % num_attention_heads_per_partition + if residual_heads == 0: + residual_heads = num_attention_heads_per_partition + max_dim = max_dim * residual_heads + + if rank < valid_ranks: # Ranks without any dummy attention heads. + var = input_.sum(-1, keepdim=True) + elif rank == valid_ranks: # The only rank which may contain 'residual_heads' dummy attention heads. 
+ var = input_[..., :max_dim].sum(-1, keepdim=True) + else: + var = input_.sum(-1, keepdim=True) * 0.0 # All heads in these ranks are dummy heads: Zero-out. + + tensor_list = [torch.empty_like(var) for _ in range(world_size)] + tensor_list[rank] = var + torch.distributed.all_gather(tensor_list, var, group=get_tensor_model_parallel_group()) + + output = torch.cat(tensor_list, dim=last_dim).contiguous() + + return output.sum(-1, keepdim=True) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata={}): + + # in InternVitSelfAttention the q_layernorm and k_layernorm weights + # are tensor-parallel so must be converted to sharded tensors + if 'q_layernorm' in prefix or 'k_layernorm' in prefix: + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint( + state_dict, prefix, {'weight': 0}, sharded_offsets + ) + else: + return super().sharded_state_dict(prefix, sharded_offsets, metadata) + + +def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec: + # Dense MLP w/ or w/o TE modules. + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear, + linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, + ), + ) + + +# Handle InternViT's layer scaling. +def _bias_dropout_add_func_internvit(ls, x_with_bias, residual, prob, training): + x, bias = x_with_bias # unpack + residual = residual if residual.dtype == x.dtype else residual.to(x.dtype) + if bias is not None: + x = x + bias + out = torch.nn.functional.dropout(x, p=prob, training=training) + out = residual + out * ls + return out + else: + out = torch.nn.functional.dropout(x, p=prob, training=training) + out = residual + out * ls + return out + + +def bias_dropout_add_unfused_internvit(ls, training): + """Bias-dropout-add as in Megatron but with added LayerScaling handling.""" + + def _bias_dropout_add(x_with_bias, residual, prob): + return _bias_dropout_add_func_internvit(ls, x_with_bias, residual, prob, training) + + return _bias_dropout_add + + +def get_bias_dropout_add_internvit(ls, training, fused): + """Bias-dropout-add as in Megatron but with added LayerScaling handling.""" + assert not fused, "Fused bias-dropout-add not implemented for InternViT." + return bias_dropout_add_unfused_internvit(ls, training) + + +# Add InternViT specialties to our default TransformerLayer. +class InternViTTransformerLayer(TransformerLayer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.ls1 = torch.nn.Parameter(torch.ones(self.config.hidden_size)) + self.ls2 = torch.nn.Parameter(torch.ones(self.config.hidden_size)) + + self.self_attn_bda = partial(self.self_attn_bda, self.ls1) + self.mlp_bda = partial(self.mlp_bda, self.ls2) + + +# Override a few things that are special in InternViT and not supported by the SelfAttention class. +class InternViTSelfAttention(SelfAttention): + def __init__( + self, config: TransformerConfig, submodules: SelfAttentionSubmodules, *args, **kwargs + ): + super().__init__(config=config, submodules=submodules, *args, **kwargs) + + # Need to override linear_qkv, q_layernorm and k_layernorm. 
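To recap the padding scheme these overrides implement: InternViT's 25 real attention heads are padded with 7 zero heads to reach 32, so that tensor parallel size 8 divides the head count. Ranks 0 to 5 then hold only real heads, rank 6 holds one real and three dummy heads, and rank 7 holds only dummy heads, which is why the variance gather and the attention output masking treat those ranks differently. A small standalone illustration of that per-rank arithmetic (constants follow the comments in this file; the loop is not part of the module code):

```python
import torch

REAL_HEADS = 25      # InternViT-6B attention heads
PADDED_HEADS = 32    # padded so that TP size 8 divides evenly
TP_SIZE = 8
HEAD_DIM = 128       # hidden size 3200 / 25 heads

heads_per_rank = PADDED_HEADS // TP_SIZE            # 4
boundary_rank = REAL_HEADS // heads_per_rank        # 6: first rank that holds dummy heads
residual_heads = REAL_HEADS % heads_per_rank        # 1 real head left on that rank

for rank in range(TP_SIZE):
    if rank < boundary_rank:
        real = heads_per_rank                       # ranks 0-5: all heads are real
    elif rank == boundary_rank:
        real = residual_heads                       # rank 6: 1 real + 3 dummy heads
    else:
        real = 0                                    # rank 7: dummy heads only

    # Mask over the flattened per-rank head dimension, analogous to the zeroing
    # done in the InternViTTEDotProductAttention override later in this file.
    mask = torch.zeros(heads_per_rank * HEAD_DIM)
    mask[: real * HEAD_DIM] = 1.0
    print(f"rank {rank}: {real} real heads, mask keeps {int(mask.sum())} of {mask.numel()} dims")
```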
+ qkv_bias = False + + self.linear_qkv = build_module( + submodules.linear_qkv, + self.config.hidden_size, + self.query_projection_size + 2 * self.kv_projection_size, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=qkv_bias, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name='qkv', + ) + + qk_layernorm_hidden_size = ( + self.hidden_size_per_attention_head * self.num_attention_heads_per_partition + ) # 512 for internvit + + self.q_layernorm = build_module( + submodules.q_layernorm, + hidden_size=qk_layernorm_hidden_size, + config=self.config, + eps=self.config.layernorm_epsilon, + compute_var=True, + ) + + self.k_layernorm = build_module( + submodules.k_layernorm, + hidden_size=qk_layernorm_hidden_size, + config=self.config, + eps=self.config.layernorm_epsilon, + compute_var=True, + ) + + +class InternViTTEDotProductAttention(TEDotProductAttention): + """Adjusted Attention for InternViT""" + + def forward(self, *args, **kwargs): + """Regular TEDotProductAttention + zero-out dummy attention heads.""" + out = super().forward(*args, **kwargs) + + # This makes sure the dummy attention heads are zeroed out. + mask = torch.ones_like(out, dtype=out.dtype, device=out.device) + rank = get_tensor_model_parallel_rank() + max_dim = out.shape[-1] # 128 + valid_ranks = 6 + + if rank == valid_ranks: + mask[..., max_dim:] *= 0.0 + elif rank > valid_ranks: + mask *= 0.0 + out *= mask + + return out + + +def get_internvit_layer_spec(use_te) -> ModuleSpec: + mlp = get_mlp_module_spec(use_te) # no norm + + return ModuleSpec( + module=InternViTTransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=InternViTRMSNorm, + self_attention=ModuleSpec( + module=InternViTSelfAttention, + params={"attn_mask_type": AttnMaskType.no_mask}, + submodules=SelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear if use_te else ColumnParallelLinear, + core_attention=TEDotProductAttention if use_te else DotProductAttention, + linear_proj=TERowParallelLinear if use_te else RowParallelLinear, + q_layernorm=InternViTRMSNorm, + k_layernorm=InternViTRMSNorm, + ), + ), + self_attn_bda=get_bias_dropout_add_internvit, + pre_mlp_layernorm=InternViTRMSNorm, + mlp=mlp, + mlp_bda=get_bias_dropout_add_internvit, + ), + ) diff --git a/examples/multimodal/nvlm/nvlm_prompts.json b/examples/multimodal/nvlm/nvlm_prompts.json new file mode 100644 index 0000000000..ab36adc765 --- /dev/null +++ b/examples/multimodal/nvlm/nvlm_prompts.json @@ -0,0 +1,165 @@ +{ + "COMMENT": "Mixture of our own custom prompts and some prompts from https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/viewer and https://huggingface.co/datasets/HuggingFaceM4/M3IT", + "Captioning": { + "raw": [ + "Can you briefly explain what you see in the image?", + "Describe what's happening in this image in one short sentence.", + "Write a short caption that accurately represents the content of this image.", + "Please generate a descriptive caption for the image provided.", + "How would you summarize the scene depicted in the picture in short?", + "Describe the image briefly.", + "Write a succinct description of the image, capturing its main components, the relationships between them, and any notable details.", + "Create a concise caption that accurately describes the main elements in the image provided.", + "Write a brief, yet comprehensive, description of the image.", + "Describe the image in a clear and concise manner.", + "For the given image, provide a one-sentence summary that captures the most 
important details.", + "Generate a short caption for the picture.", + "Write a short and informative description that highlights the primary subjects and actions occurring in the given image.", + "Provide a concise and informative caption for the image, focusing on the primary subjects.", + "Write a clear description of the image, make sure the key features are well covered.", + "Offer a succinct explanation of the picture presented." + ] + }, + "CaptioningPretraining": { + "raw": [ + "Give a brief description of image.", + "Give a brief description of the image.", + "Provide a brief description of the given image.", + "Provide a one-sentence caption for the provided image.", + "Write a terse but informative summary of the picture.", + "Describe the image concisely.", + "Generate a clear and concise summary of the photo." + ] + }, + "CaptioningSFT": { + "raw": [ + "Give a brief description of the image.", + "Give a short and clear explanation of the subsequent image.", + "Present a compact description of the photo's key features.", + "Provide a brief description of the given image.", + "Provide a one-sentence caption for the provided image.", + "Render a clear and concise summary of the photo.", + "Share a concise interpretation of the image provided.", + "Summarize the visual content of the image.", + "Write a terse but informative summary of the picture.", + "Describe the image concisely." + ] + }, + "VQAPretraining": { + "raw": [ + "Question: {} Short answer:", + "Question: {} Answer:" + ] + }, + "VQASFT": { + "raw": [ + "{}", + "{}\nAnswer the question using a single word or phrase." + ], + "docvqa": [ + "{}", + "{}\nAnswer this question using the text in the image directly." + ] + }, + "DocPretraining": { + "raw": [ + "Retrieve the text from the given pdf image.", + "Extract the text from the provided document.", + "Transcribe the text displayed in the image." + ], + "ocr_multi": [ + "Apply grounded Optical Character Recognition (OCR) to the provided image.", + "Extract all texts and their bounding boxes from the given image using grounded OCR.", + "Extract and transcribe all visible text from the provided image, ensuring accurate spatial recognition.", + "Conduct a detailed optical character recognition analysis on this image, maintaining the text's original layout and positioning.", + "Execute a thorough text recognition procedure on this visual input, ensuring that the spatial arrangement of the text is accurately represented.", + "Perform an in-depth OCR scan of the image, capturing both the content and contextual positioning of all textual information.", + "OCR with grounding:" + ], + "md": [ + "Extract the text from the given image and format it in Markdown.", + "Convert the text from the provided image into Markdown format.", + "Transform the text from the given image into Markdown syntax.", + "Extract and convert the text from the image to Markdown.", + "Retrieve the text from the image and present it in Markdown format." + ], + "grounded_ocr": [ + "{}. Text:", + "Recognize the text in this region: {}.", + "Identify the text in this area: {}.", + "Detect the text within this section: {}." + ], + "referring_grounding": [ + "Region of \"{}\" is:", + "Locate the text \"{}\" in the image.", + "Identify the text \"{}\" in the image and provide the coordinates." 
+ ] + }, + "CaptioningDetailed": { + "raw": [ + "Create a comprehensive paragraph that captures the essence of the image while weaving a cohesive narrative around its elements.", + "Compose a paragraph that thoroughly describes the image's content, providing context and connections between different aspects of the scene.", + "Provide a detailed, paragraph-length description of the image that paints a vivid picture and tells a coherent story.", + "Write a rich and engaging paragraph that delves into the image's components, describing not only what is seen but also how the elements relate to one another.", + "Give a well-rounded, paragraph-length explanation of the image, describing the scene and its components while forming a complete and engaging narrative.", + "Produce a paragraph that not only describes the individual elements in the image but also weaves them together to form a cohesive, connected account.", + "Construct a paragraph that captures the image's details and context, offering a more in-depth and engaging story than a simple caption.", + "Compose a descriptive paragraph that brings the image to life through detailed storytelling, connecting the various visual elements into a unified narrative.", + "Create a paragraph that provides an extensive and interconnected description of the image, ensuring that the narrative is both detailed and cohesive.", + "Write a compelling and detailed paragraph that delves into the image's components, linking them together to create a unified and engaging story." + ] + }, + "OCR": { + "raw": [ + "Can you read the text from image and output here?", + "Extract and document the text from the provided image.", + "Converting the text embedded in this image into a readable document.", + "Transcribe all the text you find.", + "Can you extract all visible text from the image here?" + ], + "markdown": [ + "Can you extract all visible text from the provided image?", + "Converting the text embedded in this image into a readable markdown document.", + "Can you read the text in the document as markdown?", + "Transcribe the document as markdown.", + "Extract and document the text from the provided image." + ], + "table_markdown": [ + "Can you extract all visible text from the provided table?", + "Can you read the text in the provided table as markdown?", + "Transcribe the table as markdown.", + "Extract and document the text from the provided table image." + ], + "plain": [ + "Transcribe the document as plain text.", + "Extract and document the text from the provided image.", + "Converting the text embedded in this image into a readable document.", + "Transcribe all the text you find.", + "Can you extract all visible text from the image here?" + ], + "bbox_plain": [ + "Transcribe the document as plain text along with bounding boxes.", + "Extract and document the text from the provided image along with bounding boxes.", + "Converting the text embedded in this image into a readable documen along with bounding boxes.", + "Can you extract all visible text with bounding boxes from the image here?" 
+ ] + }, + "VQA": { + "raw": [ + "Given the image, answer the following question with few words.", + "Answer the following question: ", + "What is the answer to this question?", + "Write the answer: ", + "Please answer this question: " + ] + }, + "Embedded": { + "raw": [ + "Given the image, answer the following question with few words.", + "Answer the following question: ", + "What is the answer to this question?", + "Write the answer: ", + "Please answer this question: " + ] + } +} diff --git a/examples/multimodal/nvlm/pp_checkpoint_converter.py b/examples/multimodal/nvlm/pp_checkpoint_converter.py new file mode 100644 index 0000000000..7e99d650b1 --- /dev/null +++ b/examples/multimodal/nvlm/pp_checkpoint_converter.py @@ -0,0 +1,180 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import argparse +import os +import sys + +import torch + +# Add megatron to the path. +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir, os.path.pardir)) +) + + +def split(input_dir, base_output_dir, input_pp, output_pp, num_tp, num_layers_per_pp_rank): + """Split pipeline parallel size = 1 checkpoint to pipeline parallel size N.""" + for tp in range(num_tp): + path = os.path.join(input_dir, f"mp_rank_0{tp}", "model_optim_rng.pt") + sd = torch.load(path) + + if num_layers_per_pp_rank is None: + num_layers = sd["args"].num_layers + assert num_layers % output_pp == 0, "specify --num-layers-per-pp-rank for an uneven split" + num_layers_per_pp_rank = [num_layers // output_pp] * output_pp + + layer_lb = 0 + for pp in range(output_pp): + assert num_layers_per_pp_rank[pp] > 0, "each pp rank must have at least 1 layer" + layer_ub = layer_lb + num_layers_per_pp_rank[pp] + + new_sd = sd.copy() + new_sd["model"] = dict() + for k, v in sd["model"].items(): + # First pp rank has vision model. + if pp == 0 and ("vision_model" in k or "vision_projection" in k): + new_sd["model"][k] = v + continue + + # Only the first pp rank has the word embeddings. + if "language_model.embedding.word_embeddings" in k and pp == 0: + new_sd["model"][k] = v + + # Only the last pp rank has the output layer. + if "language_model.output_layer" in k and pp == output_pp - 1: + new_sd["model"][k] = v + + # Only the last pp rank has final layer norm. + if "language_model.decoder.final_layernorm" in k and pp == output_pp - 1: + new_sd["model"][k] = v + + if "language_model.decoder.layers" in k: + layer_num = int(k.split(".")[3]) + + if layer_lb <= layer_num and layer_num < layer_ub: + # On all pp ranks, megatron starts layer nums from 0! + new_layer_num = int(layer_num - layer_lb) + + k_splitted = k.split(".") + k_splitted[3] = str(new_layer_num) + new_k = ".".join(k_splitted) + + new_sd["model"][new_k] = v + + output_dir = os.path.join(base_output_dir, f"iter_0000001/mp_rank_0{tp}_00{pp}") + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, "model_optim_rng.pt") + torch.save(new_sd, output_path) + + print(f"processed tp rank: {tp}/{num_tp - 1} and pp rank: {pp}/{output_pp - 1}") + + layer_lb = layer_ub + + # This is needed for megatron checkpoint loading. 
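The split path in pp_checkpoint_converter.py above walks the full state dict once per destination pipeline stage, keeps only the decoder layers whose global index falls inside that stage's [layer_lb, layer_ub) window, and renumbers them from zero. A standalone sketch of that index bookkeeping (the 10-layer uneven split is made up for illustration):

```python
def stage_ranges(num_layers_per_pp_rank):
    """Return the [lower, upper) global layer range owned by each pipeline stage."""
    ranges, lb = [], 0
    for n in num_layers_per_pp_rank:
        assert n > 0, "each pp rank must have at least 1 layer"
        ranges.append((lb, lb + n))
        lb += n
    return ranges


def remap(global_layer, ranges):
    """Map a global decoder layer index to (pipeline stage, local layer index)."""
    for stage, (lb, ub) in enumerate(ranges):
        if lb <= global_layer < ub:
            return stage, global_layer - lb
    raise ValueError(f"layer {global_layer} is outside all stages")


# Example: an uneven 4-way split of a hypothetical 10-layer decoder.
ranges = stage_ranges([3, 3, 2, 2])
assert remap(0, ranges) == (0, 0)
assert remap(4, ranges) == (1, 1)
assert remap(9, ranges) == (3, 1)
```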
+ with open(os.path.join(base_output_dir, "latest_checkpointed_iteration.txt"), "w") as f: + f.write("1") + + +def combine(input_dir, base_output_dir, input_pp, output_pp, num_tp, num_layers_per_pp_rank): + """Combine pipeline parallel size = N checkpoint to pipeline parallel size 1.""" + for tp in range(num_tp): + new_sd = None + + layer_num_offset = 0 + max_layer_num = 0 + + for pp in range(input_pp): + path = os.path.join(input_dir, f"mp_rank_0{tp}_00{pp}", "model_optim_rng.pt") + sd = torch.load(path) + + if pp == 0: + new_sd = sd.copy() + new_sd["model"] = dict() + new_sd["args"].pipeline_model_parallel_size = 1 + + assert new_sd is not None + + for k, v in sd["model"].items(): + # First pp rank has vision model. + if pp == 0 and ("vision_model" in k or "vision_projection" in k): + new_sd["model"][k] = v + continue + + # Only the first pp rank has the word embeddings. + if "language_model.embedding.word_embeddings" in k and pp == 0: + new_sd["model"][k] = v + + # Only the last pp rank has the output layer. + if "language_model.output_layer" in k and pp == input_pp - 1: + new_sd["model"][k] = v + + # Only the last pp rank has final layer norm. + if "language_model.decoder.final_layernorm" in k and pp == input_pp - 1: + new_sd["model"][k] = v + + if "language_model.decoder.layers" in k: + layer_num = int(k.split(".")[3]) + + # On all pp ranks, megatron starts layer nums from 0! + new_layer_num = layer_num_offset + layer_num + + if new_layer_num > max_layer_num: + max_layer_num = new_layer_num + + k_splitted = k.split(".") + k_splitted[3] = str(new_layer_num) + new_k = ".".join(k_splitted) + + new_sd["model"][new_k] = v + + print(f"processed tp rank: {tp}/{num_tp - 1} and pp rank: {pp}/{input_pp - 1}") + + layer_num_offset = max_layer_num + 1 + + output_dir = os.path.join(base_output_dir, f"iter_0000001/mp_rank_0{tp}") + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, "model_optim_rng.pt") + torch.save(new_sd, output_path) + + # This is needed for megatron checkpoint loading. 
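In both directions the converter rewrites decoder-layer keys by splitting on dots and replacing the numeric field, relying on keys of the form `language_model.decoder.layers.<N>.<rest>`, where the layer index is always the fourth dot-separated component. A hedged sketch of that rewrite (the example key is hypothetical but follows that naming pattern):

```python
def renumber_layer_key(key: str, new_layer_num: int) -> str:
    """Rewrite the layer index in a 'language_model.decoder.layers.N....' key."""
    parts = key.split(".")
    assert parts[:3] == ["language_model", "decoder", "layers"], key
    parts[3] = str(new_layer_num)
    return ".".join(parts)


old_key = "language_model.decoder.layers.45.self_attention.linear_qkv.weight"
print(renumber_layer_key(old_key, 5))
# -> language_model.decoder.layers.5.self_attention.linear_qkv.weight
```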
+ with open(os.path.join(base_output_dir, "latest_checkpointed_iteration.txt"), "w") as f: + f.write("1") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Change pipeline parallelism for a model", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "--input", type=str, required=True, help="Input model directory" + ) + parser.add_argument( + "--input-pipeline-parallel", type=int, required=True, help="Input model pipeline parallelism" + ) + parser.add_argument( + "--output", type=str, required=True, help="Output model directory" + ) + parser.add_argument( + "--output-pipeline-parallel", type=int, required=True, help="Output model pipeline parallelism" + ) + parser.add_argument( + "--tensor-parallel", type=int, required=True, help="Model tensor parallel size", + ) + parser.add_argument( + "--num-layers-per-pp-rank", type=int, default=None, nargs="*", help="Specify this for uneven pipeline parallel split", + ) + + args = parser.parse_args() + + f = None + if args.input_pipeline_parallel == 1 and args.output_pipeline_parallel > 1: + f = split + elif args.input_pipeline_parallel > 1 and args.output_pipeline_parallel == 1: + f = combine + else: + raise NotImplementedError("Only pipeline parallel 1 to N and N to 1 are supported") + + f(args.input, args.output, args.input_pipeline_parallel, args.output_pipeline_parallel, args.tensor_parallel, args.num_layers_per_pp_rank) + + print("done.") diff --git a/examples/multimodal/nvlm/pretrain_blend.yaml b/examples/multimodal/nvlm/pretrain_blend.yaml new file mode 100644 index 0000000000..fbbcc54388 --- /dev/null +++ b/examples/multimodal/nvlm/pretrain_blend.yaml @@ -0,0 +1,28 @@ +__module__: megatron.energon +__class__: Metadataset +splits: + train: + datasets: + - weight: 0.579 # Datasets are weighted according to their size. Weights sum up to 1. + path: + subflavors: + augmentation: False + + - weight: 0.02 + path: + subflavors: + augmentation: False + + - weight: 0.01 + path: + subflavors: + augmentation: False + + # Please refer to Table 4 in https://arxiv.org/pdf/2409.11402 for full list of pretrain datasets. + # Please refer to https://nvidia.github.io/Megatron-Energon/data_prep.html on preparing datasets in the Megatron Energon format. + val: + datasets: + - weight: 1. + path: + subflavors: + augmentation: False diff --git a/examples/multimodal/nvlm/pretrain_qwen20_72b_internvit_6b.sh b/examples/multimodal/nvlm/pretrain_qwen20_72b_internvit_6b.sh new file mode 100644 index 0000000000..008a17ac43 --- /dev/null +++ b/examples/multimodal/nvlm/pretrain_qwen20_72b_internvit_6b.sh @@ -0,0 +1,158 @@ +#!/bin/bash + +# Your SBATCH commands here if using SLURM. + +# Please launch this script from megatron-lm root. + +# Train a multimodal model. 
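The pretraining script that starts here (like its Yi-34B counterpart) passes --freeze-LM and --freeze-ViT together with --allow-missing-vision-projection-checkpoint, so during this stage only the vision projection connecting the two towers is trained. In plain PyTorch terms, freezing amounts to turning off gradients for those submodules; a generic sketch under that assumption (the toy module names below are placeholders, not Megatron's actual attributes):

```python
import torch


def freeze(module: torch.nn.Module) -> None:
    """Exclude a submodule from training by disabling its gradients."""
    for param in module.parameters():
        param.requires_grad = False


class ToyVLM(torch.nn.Module):
    """Hypothetical stand-in for a vision tower + projection + language model."""

    def __init__(self):
        super().__init__()
        self.vision_model = torch.nn.Linear(3200, 3200)
        self.vision_projection = torch.nn.Linear(3200, 8192)
        self.language_model = torch.nn.Linear(8192, 8192)


model = ToyVLM()
freeze(model.vision_model)    # --freeze-ViT
freeze(model.language_model)  # --freeze-LM
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # only vision_projection.* parameters remain trainable
```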
+ +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export TOKENIZERS_PARALLELISM="false" + +DEBUG=0 + +if [[ $BATCH -eq 0 ]]; then + DATETIME=`date +'%y-%m-%d-%H-%M-%S'` + MODEL_NAME="mcore-qwen20-72b-internvit-${DATETIME}" +else + MODEL_NAME="mcore-qwen20-72b-internvit" +fi + +WORKSPACE="" +SOURCE=`pwd` +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" + +FINETUNE_DIR=${OUTPUT}/checkpoints +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" + +CHECKPOINT_DIR="${WORKSPACE}/combined-qwen2.0-72b-instruct-internvit-6b-448px-1.5-tp8-te" + +DATA_TRAIN="${SOURCE}/examples/multimodal/nvlm/pretrain_blend.yaml" + +if [[ $DEBUG -eq 1 ]]; then + MBZ=1 + BZ=1 + NW=0 + AD=0.0 + HD=0.0 + LI=1 + EXTRA_ARGS="" + ALLOW_NONDETERMINISTIC=1 +else + MBZ=1 + BZ=2048 + NW=8 + AD=0.1 + HD=0.1 + LI=5 + EXTRA_ARGS="" + ALLOW_NONDETERMINISTIC=1 +fi + +SEQ_LEN=256 # Image embeddings sequence length. +DECODER_SEQ_LEN=512 # Language model sequence length. +MAX_POS_EMBED=512 + + +OPTIONS=" \ + --use-checkpoint-args \ + --exit-duration-in-mins 230 \ + --disable-bias-linear \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model Qwen/Qwen2-72B-Instruct \ + --tokenizer-prompt-format qwen2p0 \ + --transformer-impl transformer_engine \ + --normalization RMSNorm \ + --norm-epsilon 1e-06 \ + --group-query-attention \ + --num-query-groups 8 \ + --no-masked-softmax-fusion \ + --attention-softmax-in-fp32 \ + --attention-dropout ${AD} \ + --hidden-dropout ${HD} \ + --untie-embeddings-and-output-weights \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --tensor-model-parallel-size 8 \ + --pipeline-model-parallel-size 1 \ + --num-layers 80 \ + --hidden-size 8192 \ + --ffn-hidden-size 29568 \ + --add-qkv-bias \ + --num-attention-heads 64 \ + --use-distributed-optimizer \ + --use-te \ + --num-workers ${NW} \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --max-position-embeddings 32768 \ + --train-samples 122880000 \ + --lr-decay-samples 25600000 \ + --lr-warmup-samples 83200 \ + --micro-batch-size ${MBZ} \ + --global-batch-size ${BZ} \ + --lr 1e-4 \ + --min-lr 2.5e-5 \ + --lr-decay-style cosine \ + --log-interval ${LI} \ + --eval-iters 10 \ + --eval-interval 500 \ + --data-path ${DATA_TRAIN} \ + --prompt-path ${SOURCE}/examples/multimodal/nvlm/nvlm_prompts.json \ + --save-interval 5000 \ + --save ${FINETUNE_DIR} \ + --load ${FINETUNE_DIR} \ + --dataloader-save ${FINETUNE_DIR}/dataloader \ + --pretrained-checkpoint ${CHECKPOINT_DIR} \ + --split 100,0,0 \ + --clip-grad 10.0 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.014 \ + --bf16 \ + --eod-mask-loss \ + --freeze-ViT \ + --freeze-LM \ + --patch-dim 14 \ + --img-h 448 \ + --img-w 448 \ + --dataloader-type external \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + --language-model-type qwen2.0_72B \ + ${EXTRA_ARGS} \ + --allow-missing-vision-projection-checkpoint \ + --vision-model-type internvit \ + --disable-vision-class-token \ + --log-params-norm \ + --log-num-zeros-in-grad \ + --ckpt-format torch \ + --pixel-shuffle \ + --image-tag-type nvlm +" + + +export NVTE_APPLY_QK_LAYER_SCALING=0 +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${ALLOW_NONDETERMINISTIC} + +# Interactive or batch mode +if [[ $BATCH -eq 0 ]]; then + torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} +else + run_cmd="python -u ${SOURCE}/examples/multimodal/train.py ${OPTIONS}" + + DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` 
+ + srun -l --verbose \ + --container-image \ + --container-mounts "" \ + --output=${LOGS_DIR}/%x_%j_$DATETIME.log \ + sh -c "${run_cmd}" + + set +x +fi diff --git a/examples/multimodal/nvlm/pretrain_yi_34b_internvit_6b.sh b/examples/multimodal/nvlm/pretrain_yi_34b_internvit_6b.sh new file mode 100644 index 0000000000..0ec80eacc4 --- /dev/null +++ b/examples/multimodal/nvlm/pretrain_yi_34b_internvit_6b.sh @@ -0,0 +1,155 @@ +#!/bin/bash + +# Your SBATCH commands here if using SLURM. + +# Please launch this script from megatron-lm root. + +# Train a multimodal model. + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export TOKENIZERS_PARALLELISM="false" + +DEBUG=0 + +if [[ $BATCH -eq 0 ]]; then + DATETIME=`date +'%y-%m-%d-%H-%M-%S'` + MODEL_NAME="mcore-nous-yi34b-internvit-mlp-${DATETIME}" +else + MODEL_NAME="mcore-nous-yi34b-internvit-mlp" +fi + +WORKSPACE="" +SOURCE=`pwd` +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" + +FINETUNE_DIR=${OUTPUT}/checkpoints +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" + +LOAD_NAME="combined-yi-34b-internvit-tp8-mcore" +CHECKPOINT_DIR="${WORKSPACE}/${LOAD_NAME}" + +DATA_TRAIN="${SOURCE}/examples/multimodal/nvlm/pretrain_blend.yaml" + + +if [[ $DEBUG -eq 1 ]]; then + MBZ=1 + BZ=1 + NW=0 + LI=1 + AD=0.0 + HD=0.0 + EXTRA_ARGS="" + ALLOW_NONDETERMINISTIC=1 +else + MBZ=1 + BZ=2048 + NW=8 + LI=5 + AD=0.1 + HD=0.1 + EXTRA_ARGS="" + ALLOW_NONDETERMINISTIC=1 +fi + +SEQ_LEN=256 # Image embeddings sequence length. +DECODER_SEQ_LEN=512 # Language model sequence length. +MAX_POS_EMBED=512 + + +OPTIONS=" \ + --swiglu \ + --use-distributed-optimizer \ + --num-workers ${NW} \ + --num-layers 60 \ + --hidden-size 7168 \ + --normalization RMSNorm \ + --num-attention-heads 56 \ + --exit-duration-in-mins 230 \ + --group-query-attention \ + --num-query-groups 8 \ + --ffn-hidden-size 20480 \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --max-position-embeddings ${MAX_POS_EMBED} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model NousResearch/Nous-Hermes-2-Yi-34B \ + --tokenizer-prompt-format nvlm-yi-34b \ + --vocab-size 64000 \ + --make-vocab-size-divisible-by 1 \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 5000000 \ + --disable-bias-linear \ + --tensor-model-parallel-size 8 \ + --language-model-type yi-34b \ + --vision-model-type internvit \ + --micro-batch-size ${MBZ} \ + --global-batch-size ${BZ} \ + --train-samples 122880000 \ + --lr-decay-samples 25600000 \ + --lr-warmup-samples 83200 \ + --lr 1e-4 \ + --min-lr 2.5e-5 \ + --lr-decay-style cosine \ + --clip-grad 10.0 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.014 \ + --attention-dropout ${AD} \ + --hidden-dropout ${HD} \ + --untie-embeddings-and-output-weights \ + --eod-mask-loss \ + --bf16 \ + --tensorboard-dir=${TENSORBOARD_DIR} \ + --freeze-LM \ + --freeze-ViT \ + --img-h 448 \ + --img-w 448 \ + --patch-dim 14 \ + --data-path ${DATA_TRAIN} \ + --dataloader-type external \ + --split 100,0,0 \ + --prompt-path ${SOURCE}/examples/multimodal/nvlm/nvlm_prompts.json \ + --log-interval ${LI} \ + --save-interval 2000 \ + --eval-interval 500 \ + --eval-iters 10 \ + --log-params-norm \ + --log-num-zeros-in-grad \ + ${EXTRA_ARGS} \ + --save ${FINETUNE_DIR} \ + --load ${FINETUNE_DIR} \ + --dataloader-save ${FINETUNE_DIR}/dataloader \ + --pretrained-checkpoint ${CHECKPOINT_DIR} \ + --allow-missing-vision-projection-checkpoint \ + --disable-vision-class-token \ 
+ --use-te \ + --use-checkpoint-args \ + --ckpt-format torch \ + --pixel-shuffle \ + --image-tag-type nvlm + " + +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${ALLOW_NONDETERMINISTIC} +export NVTE_APPLY_QK_LAYER_SCALING=0 + +# Interactive or batch mode +if [[ $BATCH -eq 0 ]]; then + torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} +else + run_cmd="python -u ${SOURCE}/examples/multimodal/train.py ${OPTIONS}" + + DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` + + srun -l --verbose \ + --container-image \ + --container-mounts "" \ + --output=${LOGS_DIR}/%x_%j_$DATETIME.log \ + sh -c "${run_cmd}" + + set +x +fi diff --git a/examples/multimodal/nvlm/run_text_generation_qwen20_72b_internvit_6b.sh b/examples/multimodal/nvlm/run_text_generation_qwen20_72b_internvit_6b.sh new file mode 100755 index 0000000000..e3b001c7aa --- /dev/null +++ b/examples/multimodal/nvlm/run_text_generation_qwen20_72b_internvit_6b.sh @@ -0,0 +1,141 @@ +#!/bin/bash + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NVTE_APPLY_QK_LAYER_SCALING=0 +export TOKENIZERS_PARALLELISM="false" + +INPUT_IMAGE_PATH="placeholder" +GROUNDTRUTH_PATH="placeholder" + +USE_TILING=0 +USE_PIXEL_SHUFFLE_ONLY=0 + +while [[ $# -gt 0 ]]; do + case $1 in + --input-image-path) + INPUT_IMAGE_PATH="$2" + shift + shift + ;; + -o|--output-path) + OUTPUT_PATH="$2" + shift + shift + ;; + -m|--model-path) + MODEL_PATH="$2" + shift + shift + ;; + --task) + TASK="$2" + shift + shift + ;; + -g|--gt-path) + GROUNDTRUTH_PATH="$2" + shift + shift + ;; + --use-tiling) + USE_TILING=1 + shift + shift + ;; + --use-pixel-shuffle-only) + USE_PIXEL_SHUFFLE_ONLY=1 + shift + shift + ;; + -*|--*) + echo "Invalid option $1" + exit 1 + ;; + esac +done + +# Please modify these as needed. +NUM_PARTITIONS=0 +START=0 +END=0 + +SEQ_LEN=1024 # Image embeddings sequence length. +DECODER_SEQ_LEN=8192 # Language model sequence length. +MAX_POS_EMBED=8192 + +# Additional arguments. +EXTRA_ARGS="" + +if [[ $USE_TILING -eq 1 ]]; then + EXTRA_ARGS+=" --pixel-shuffle --use-tiling --max-num-tiles 6 --use-thumbnail --use-tile-tags" + SEQ_LEN=261 # Image embeddings sequence length (256 image embeddings + 5 tile tag embeddings). 
+fi + +if [[ $USE_PIXEL_SHUFFLE_ONLY -eq 1 ]]; then + EXTRA_ARGS+=" --pixel-shuffle" + SEQ_LEN=256 +fi + +for PARTITION_ID in $( eval echo {$START..$END} ) +do + torchrun --nproc_per_node 8 examples/multimodal/run_text_generation.py \ + --attention-softmax-in-fp32 \ + --no-masked-softmax-fusion \ + --swiglu \ + --num-layers 80 \ + --hidden-size 8192 \ + --normalization RMSNorm \ + --norm-epsilon 1e-06 \ + --num-attention-heads 64 \ + --exit-on-missing-checkpoint \ + --group-query-attention \ + --num-query-groups 8 \ + --ffn-hidden-size 29568 \ + --load ${MODEL_PATH} \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --max-position-embeddings ${MAX_POS_EMBED} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model Qwen/Qwen2-72B-Instruct \ + --tokenizer-prompt-format qwen2p0 \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --disable-bias-linear \ + --add-qkv-bias \ + --tensor-model-parallel-size 8 \ + --pipeline-model-parallel-size 1 \ + --language-model-type qwen2.0_72B \ + --vision-model-type internvit \ + --micro-batch-size 1 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --bf16 \ + --freeze-LM \ + --freeze-ViT \ + --img-h 448 \ + --img-w 448 \ + --patch-dim 14 \ + --use-te \ + --transformer-impl transformer_engine \ + --use-checkpoint-args \ + --out-seq-length 16 \ + --temperature 1.0 \ + --patch-dim 14 \ + --seed 1234 \ + --top_k 1 \ + --no-load-rng \ + --no-load-optim \ + --num-partitions ${NUM_PARTITIONS} \ + --partition-id ${PARTITION_ID} \ + --output-path ${OUTPUT_PATH} \ + --gt-path ${GROUNDTRUTH_PATH} \ + --disable-vision-class-token \ + --input-image-path ${INPUT_IMAGE_PATH} \ + --gt-path ${GROUNDTRUTH_PATH} \ + ${EXTRA_ARGS} \ + --task ${TASK} \ + --image-tag-type nvlm \ + --ckpt-format torch +done diff --git a/examples/multimodal/nvlm/run_text_generation_qwen25_7b_internvit_video.sh b/examples/multimodal/nvlm/run_text_generation_qwen25_7b_internvit_video.sh new file mode 100755 index 0000000000..57f43347c7 --- /dev/null +++ b/examples/multimodal/nvlm/run_text_generation_qwen25_7b_internvit_video.sh @@ -0,0 +1,129 @@ +#!/bin/bash + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NVTE_APPLY_QK_LAYER_SCALING=0 +export TOKENIZERS_PARALLELISM="false" + +INPUT_IMAGE_PATH="placeholder" +GROUNDTRUTH_PATH="placeholder" + +while [[ $# -gt 0 ]]; do + case $1 in + --input-image-path) + INPUT_IMAGE_PATH="$2" + shift + shift + ;; + --input-metadata-path) + INPUT_METADATA_PATH="$2" + shift + shift + ;; + --num-frames) + NUM_FRAMES="$2" + shift + shift + ;; + -g|--groundtruth-path) + GROUNDTRUTH_PATH="$2" + shift + shift + ;; + -o|--output-path) + OUTPUT_PATH="$2" + shift + shift + ;; + -m|--model-path) + MODEL_PATH="$2" + shift + shift + ;; + --task) + TASK="$2" + shift + shift + ;; + -g|--gt-path) + GROUNDTRUTH_PATH="$2" + shift + shift + ;; + -*|--*) + echo "Invalid option $1" + exit 1 + ;; + esac +done + + +# Please modify these as needed. 
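The generation scripts in this directory can shard an evaluation set across independent jobs: the NUM_PARTITIONS, START and END variables set just below drive a loop over --partition-id values, and run_text_generation.py then processes only the samples belonging to that partition. A generic sketch of this kind of index-based sharding (the exact slicing inside run_text_generation.py may differ; this only shows the idea):

```python
def partition_indices(num_samples: int, num_partitions: int, partition_id: int):
    """Return the sample indices handled by one partition of an evaluation run."""
    if num_partitions <= 0:
        return list(range(num_samples))  # no partitioning: process everything
    assert 0 <= partition_id < num_partitions
    per_partition = (num_samples + num_partitions - 1) // num_partitions
    start = partition_id * per_partition
    end = min(start + per_partition, num_samples)
    return list(range(start, end))


# Example: 10 samples split over 3 partitions.
print(partition_indices(10, 3, 0))  # [0, 1, 2, 3]
print(partition_indices(10, 3, 2))  # [8, 9]
```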
+NUM_PARTITIONS=0 +START=0 +END=0 + +SEQ_LEN=256 +DECODER_SEQ_LEN=16384 + +EXTRA_ARGS=" --pixel-shuffle" + + +for PARTITION_ID in $( eval echo {$START..$END} ) +do + torchrun --nproc_per_node 8 examples/multimodal/run_text_generation.py \ + --attention-softmax-in-fp32 \ + --transformer-impl transformer_engine \ + --use-te \ + --use-checkpoint-args \ + --normalization RMSNorm \ + --norm-epsilon 1e-06 \ + --language-model-type=qwen2.5_7B \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --tensor-model-parallel-size 4 \ + --pipeline-model-parallel-size 1 \ + --group-query-attention \ + --num-query-groups 4 \ + --num-layers 28 \ + --hidden-size 3584 \ + --ffn-hidden-size 18944 \ + --add-qkv-bias \ + --num-attention-heads 28 \ + --max-position-embeddings 32768 \ + --no-masked-softmax-fusion \ + --load ${MODEL_PATH} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model Qwen/Qwen2.5-7B-Instruct \ + --tokenizer-prompt-format qwen2p5 \ + --bf16 \ + --micro-batch-size 1 \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --out-seq-length 128 \ + --temperature 1.0 \ + --img-h 448 \ + --img-w 448 \ + --patch-dim 14 \ + --seed 153 \ + --top_k 1 \ + --no-load-rng \ + --no-load-optim \ + --input-image-path ${INPUT_IMAGE_PATH} \ + --num-partitions ${NUM_PARTITIONS} \ + --partition-id ${PARTITION_ID} \ + --output-path ${OUTPUT_PATH} \ + --gt-path ${GROUNDTRUTH_PATH} \ + --task ${TASK} \ + ${EXTRA_ARGS} \ + --special-tokens "" "" "" \ + --vision-model-type internvit \ + --num-frames ${NUM_FRAMES} \ + --ckpt-format torch +done diff --git a/examples/multimodal/nvlm/run_text_generation_qwen25_7b_siglip.sh b/examples/multimodal/nvlm/run_text_generation_qwen25_7b_siglip.sh new file mode 100755 index 0000000000..3b6221996c --- /dev/null +++ b/examples/multimodal/nvlm/run_text_generation_qwen25_7b_siglip.sh @@ -0,0 +1,111 @@ +#!/bin/bash + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NVTE_APPLY_QK_LAYER_SCALING=0 +export TOKENIZERS_PARALLELISM="false" + +INPUT_IMAGE_PATH="placeholder" +GROUNDTRUTH_PATH="placeholder" + +while [[ $# -gt 0 ]]; do + case $1 in + -i|--input-image-path) + INPUT_IMAGE_PATH="$2" + shift + shift + ;; + -o|--output-path) + OUTPUT_PATH="$2" + shift + shift + ;; + -m|--model-path) + MODEL_PATH="$2" + shift + shift + ;; + -t|--task) + TASK="$2" + shift + shift + ;; + -g|--gt-path) + GROUNDTRUTH_PATH="$2" + shift + shift + ;; + -*|--*) + echo "Invalid option $1" + exit 1 + ;; + esac +done + +# Please modify these as needed. 
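The video variant above passes --num-frames, described earlier as the number of frames to sample regularly from the input video. A minimal sketch of evenly spaced frame selection (the real sampling lives in the multimodal dataloader and may differ in detail):

```python
import numpy as np


def sample_frame_indices(total_frames: int, num_frames: int) -> list:
    """Pick `num_frames` roughly evenly spaced frame indices from a video."""
    if total_frames <= num_frames:
        return list(range(total_frames))
    # Centers of `num_frames` equal-length segments spanning the clip.
    indices = np.linspace(0, total_frames, num=num_frames, endpoint=False)
    indices += (total_frames / num_frames) / 2
    return [int(i) for i in indices]


print(sample_frame_indices(300, 8))  # 8 frames spread across a 300-frame clip
print(sample_frame_indices(5, 8))    # short clip: just use every frame
```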
+NUM_PARTITIONS=0 +START=0 +END=0 + + +SEQ_LEN=256 +DECODER_SEQ_LEN=8192 +EXTRA_ARGS=" --pixel-shuffle --use-tiling --max-num-tiles 12 --use-thumbnail" + +for PARTITION_ID in $( eval echo {$START..$END} ) +do + torchrun --nproc_per_node 8 examples/multimodal/run_text_generation.py \ + --attention-softmax-in-fp32 \ + --transformer-impl transformer_engine \ + --use-te \ + --use-checkpoint-args \ + --normalization RMSNorm \ + --norm-epsilon 1e-06 \ + --language-model-type=qwen2.5_7B \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --tensor-model-parallel-size 4 \ + --pipeline-model-parallel-size 1 \ + --group-query-attention \ + --num-query-groups 4 \ + --num-layers 28 \ + --hidden-size 3584 \ + --ffn-hidden-size 18944 \ + --add-qkv-bias \ + --num-attention-heads 28 \ + --max-position-embeddings 32768 \ + --no-masked-softmax-fusion \ + --load ${MODEL_PATH} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model Qwen/Qwen2.5-7B-Instruct \ + --tokenizer-prompt-format qwen2p5 \ + --bf16 \ + --micro-batch-size 1 \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --out-seq-length 128 \ + --temperature 1.0 \ + --img-h 448 \ + --img-w 448 \ + --patch-dim 14 \ + --seed 153 \ + --top_k 1 \ + --no-load-rng \ + --no-load-optim \ + --input-image-path ${INPUT_IMAGE_PATH} \ + --num-partitions ${NUM_PARTITIONS} \ + --partition-id ${PARTITION_ID} \ + --output-path ${OUTPUT_PATH} \ + --gt-path ${GROUNDTRUTH_PATH} \ + --task ${TASK} \ + ${EXTRA_ARGS} \ + --special-tokens "" "" "" \ + --vision-model-type siglip \ + --ckpt-format torch +done diff --git a/examples/multimodal/nvlm/run_text_generation_yi_34b_internvit_6b.sh b/examples/multimodal/nvlm/run_text_generation_yi_34b_internvit_6b.sh new file mode 100644 index 0000000000..341f4e4b0a --- /dev/null +++ b/examples/multimodal/nvlm/run_text_generation_yi_34b_internvit_6b.sh @@ -0,0 +1,140 @@ +#!/bin/bash + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NVTE_APPLY_QK_LAYER_SCALING=0 + +INPUT_IMAGE_PATH="placeholder" +GROUNDTRUTH_PATH="placeholder" + +USE_TILING=0 +USE_PIXEL_SHUFFLE_ONLY=0 + +while [[ $# -gt 0 ]]; do + case $1 in + --input-image-path) + INPUT_IMAGE_PATH="$2" + shift + shift + ;; + -o|--output-path) + OUTPUT_PATH="$2" + shift + shift + ;; + -m|--model-path) + MODEL_PATH="$2" + shift + shift + ;; + --task) + TASK="$2" + shift + shift + ;; + -g|--gt-path) + GROUNDTRUTH_PATH="$2" + shift + shift + ;; + --use-tiling) + USE_TILING=1 + shift + shift + ;; + --use-pixel-shuffle-only) + USE_PIXEL_SHUFFLE_ONLY=1 + shift + shift + ;; + -*|--*) + echo "Invalid option $1" + exit 1 + ;; + esac +done + +# Please modify these as needed. +NUM_PARTITIONS=0 +START=0 +END=0 + +SEQ_LEN=1024 # Image embeddings sequence length. +DECODER_SEQ_LEN=8192 # Language model sequence length. +MAX_POS_EMBED=8192 + +# Additional arguments. +EXTRA_ARGS="" + +if [[ $USE_TILING -eq 1 ]]; then + EXTRA_ARGS+=" --pixel-shuffle --use-tiling --max-num-tiles 6 --use-thumbnail --use-tile-tags" + SEQ_LEN=261 # Image embeddings sequence length (256 image embeddings + 5 tile tag embeddings). 
+fi + +if [[ $USE_PIXEL_SHUFFLE_ONLY -eq 1 ]]; then + EXTRA_ARGS+=" --pixel-shuffle" + SEQ_LEN=256 +fi + +for PARTITION_ID in $( eval echo {$START..$END} ) +do + torchrun --nproc_per_node 8 examples/multimodal/run_text_generation.py \ + --attention-softmax-in-fp32 \ + --no-masked-softmax-fusion \ + --swiglu \ + --num-layers 60 \ + --hidden-size 7168 \ + --normalization RMSNorm \ + --num-attention-heads 56 \ + --exit-on-missing-checkpoint \ + --group-query-attention \ + --num-query-groups 8 \ + --ffn-hidden-size 20480 \ + --load ${MODEL_PATH} \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --max-position-embeddings ${MAX_POS_EMBED} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model NousResearch/Nous-Hermes-2-Yi-34B \ + --tokenizer-prompt-format nvlm-yi-34b \ + --vocab-size 64000 \ + --make-vocab-size-divisible-by 1 \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 5000000 \ + --disable-bias-linear \ + --tensor-model-parallel-size 8 \ + --pipeline-model-parallel-size 1 \ + --language-model-type yi-34b \ + --vision-model-type internvit \ + --micro-batch-size 1 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --bf16 \ + --freeze-LM \ + --freeze-ViT \ + --img-h 448 \ + --img-w 448 \ + --patch-dim 14 \ + --use-te \ + --transformer-impl transformer_engine \ + --use-checkpoint-args \ + --out-seq-length 16 \ + --temperature 1.0 \ + --patch-dim 14 \ + --seed 1234 \ + --top_k 1 \ + --no-load-rng \ + --no-load-optim \ + --num-partitions ${NUM_PARTITIONS} \ + --partition-id ${PARTITION_ID} \ + --output-path ${OUTPUT_PATH} \ + --gt-path ${GROUNDTRUTH_PATH} \ + --disable-vision-class-token \ + --input-image-path ${INPUT_IMAGE_PATH} \ + --gt-path ${GROUNDTRUTH_PATH} \ + ${EXTRA_ARGS} \ + --task ${TASK} \ + --image-tag-type nvlm \ + --ckpt-format torch +done diff --git a/examples/multimodal/nvlm/sft_34b_internvit.sh b/examples/multimodal/nvlm/sft_34b_internvit.sh new file mode 100644 index 0000000000..ca8d0a349c --- /dev/null +++ b/examples/multimodal/nvlm/sft_34b_internvit.sh @@ -0,0 +1,161 @@ +#!/bin/bash + +# Your SBATCH commands here if using SLURM. + +# Please launch this script from megatron-lm root. + +# Train a multimodal model. + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_ALGO=^NVLS +export TOKENIZERS_PARALLELISM="false" + + +DEBUG=0 + +if [[ $BATCH -eq 0 ]]; then + DATETIME=`date +'%y-%m-%d-%H-%M-%S'` + MODEL_NAME="mcore-nous-yi34b-internvit-mlp-sft-${DATETIME}" +else + MODEL_NAME="mcore-nous-yi34b-internvit-mlp-sft" +fi + +WORKSPACE="" +SOURCE=`pwd` +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" + +FINETUNE_DIR=${OUTPUT}/checkpoints +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" + +LOAD_NAME="mcore-nous-yi34b-internvit-mlp" # From pretraining +CHECKPOINT_DIR="${WORKSPACE}/output/${LOAD_NAME}/checkpoints" + +DATA_TRAIN="${SOURCE}/examples/multimodal/nvlm/sft_blend.yaml" + + +if [[ $DEBUG -eq 1 ]]; then + MBZ=1 + BZ=1 + NW=0 + LI=1 + AD=0.0 + HD=0.0 + ALLOW_NONDETERMINISTIC=1 + + # Can run out of GPU memory in interactive memory without this. + # This is just for interactive testing purposes. Do not use for proper training. + EXTRA_ARGS=" --freeze-LM" +else + MBZ=1 + BZ=128 + NW=2 + LI=5 + AD=0.0 + HD=0.0 + ALLOW_NONDETERMINISTIC=1 + + EXTRA_ARGS="" +fi + +SEQ_LEN=261 # Image embeddings sequence length (256 image embeddings + 5 tile tag embeddings). +DECODER_SEQ_LEN=3200 # Language model sequence length. 
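The SEQ_LEN=261 set above follows from the image settings used throughout these scripts: a 448x448 tile with 14x14 patches gives 32x32 = 1024 vision tokens, pixel shuffle (assuming the usual 0.5 spatial downscale factor) cuts that by 4x to 256 image embeddings, and the tile tags account for the remaining 5 tokens, as the comment notes. A quick arithmetic check:

```python
img_h = img_w = 448
patch = 14
pixel_shuffle_factor = 0.5   # assumed per-side downscale used by pixel shuffle
tile_tag_tokens = 5          # per the comment on SEQ_LEN above

patches_per_side = img_h // patch                              # 32
vit_tokens = patches_per_side * patches_per_side               # 1024 (SEQ_LEN without pixel shuffle)
image_embeddings = int(vit_tokens * pixel_shuffle_factor ** 2)  # 256
seq_len = image_embeddings + tile_tag_tokens                    # 261

assert (vit_tokens, image_embeddings, seq_len) == (1024, 256, 261)
```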
+MAX_POS_EMBED=3200 + +OPTIONS=" \ + --swiglu \ + --use-distributed-optimizer \ + --num-workers ${NW} \ + --num-layers 60 \ + --hidden-size 7168 \ + --normalization RMSNorm \ + --num-attention-heads 56 \ + --exit-duration-in-mins 230 \ + --group-query-attention \ + --num-query-groups 8 \ + --ffn-hidden-size 20480 \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --max-position-embeddings ${MAX_POS_EMBED} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model NousResearch/Nous-Hermes-2-Yi-34B \ + --tokenizer-prompt-format nvlm-yi-34b \ + --vocab-size 64000 \ + --make-vocab-size-divisible-by 1 \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 5000000 \ + --disable-bias-linear \ + --tensor-model-parallel-size 8 \ + --language-model-type yi-34b \ + --vision-model-type internvit \ + --micro-batch-size ${MBZ} \ + --global-batch-size ${BZ} \ + --train-samples 30000000 \ + --lr-decay-samples 25600000 \ + --lr-warmup-samples 83200 \ + --lr 2e-6 \ + --min-lr 2.5e-7 \ + --lr-decay-style cosine \ + --split 100,0,0 \ + --clip-grad 10 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.014 \ + --attention-dropout ${AD} \ + --hidden-dropout ${HD} \ + --untie-embeddings-and-output-weights \ + --eod-mask-loss \ + --bf16 \ + --tensorboard-dir=${TENSORBOARD_DIR} \ + --freeze-ViT \ + --img-h 448 \ + --img-w 448 \ + --patch-dim 14 \ + --data-path ${DATA_TRAIN} \ + --dataloader-type external \ + --dataloader-save ${FINETUNE_DIR}/dataloader \ + --prompt-path ${SOURCE}/examples/multimodal/nvlm/nvlm_prompts.json \ + --log-interval ${LI} \ + --load ${FINETUNE_DIR} \ + --save ${FINETUNE_DIR} \ + --pretrained-checkpoint ${CHECKPOINT_DIR} \ + --save-interval 5000 \ + --eval-interval 500 \ + --eval-iters 10 \ + --log-params-norm \ + --log-num-zeros-in-grad \ + ${EXTRA_ARGS} \ + --disable-vision-class-token \ + --use-te \ + --ckpt-format torch \ + --pixel-shuffle \ + --use-tiling \ + --max-num-tiles 6 \ + --use-thumbnail \ + --use-tile-tags \ + --image-tag-type nvlm + " + +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${ALLOW_NONDETERMINISTIC} +export NVTE_APPLY_QK_LAYER_SCALING=0 + +# Interactive or batch mode +if [[ $BATCH -eq 0 ]]; then + torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} +else + run_cmd="python -u ${SOURCE}/examples/multimodal/train.py ${OPTIONS}" + + DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` + + srun -l --verbose \ + --container-image \ + --container-mounts "" \ + --output=${LOGS_DIR}/%x_%j_$DATETIME.log \ + sh -c "${run_cmd}" + + set +x +fi diff --git a/examples/multimodal/nvlm/sft_blend.yaml b/examples/multimodal/nvlm/sft_blend.yaml new file mode 100644 index 0000000000..56c8230a2a --- /dev/null +++ b/examples/multimodal/nvlm/sft_blend.yaml @@ -0,0 +1,23 @@ +__module__: megatron.energon +__class__: Metadataset +splits: + train: + datasets: + - weight: 0.01 # # Datasets are weighted according to their size. Weights sum up to 1. + path: + subflavors: + augmentation: False + + - weight: 0.02 + path: + subflavors: + augmentation: False + + # Please refer to Table 6 in https://arxiv.org/pdf/2409.11402 for full list of SFT datasets. + # Please refer to https://nvidia.github.io/Megatron-Energon/data_prep.html on preparing datasets in the Megatron Energon format. + val: + datasets: + - weight: 1. 
+ path: + subflavors: + augmentation: False diff --git a/examples/multimodal/nvlm/sft_qwen20_72b_internvit_6b.sh b/examples/multimodal/nvlm/sft_qwen20_72b_internvit_6b.sh new file mode 100644 index 0000000000..3b472259b9 --- /dev/null +++ b/examples/multimodal/nvlm/sft_qwen20_72b_internvit_6b.sh @@ -0,0 +1,165 @@ +#!/bin/bash + +# Your SBATCH commands here if using SLURM. + +# Please launch this script from megatron-lm root. + +# Train a multimodal model. + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_ALGO=^NVLS +export TOKENIZERS_PARALLELISM="false" + +DEBUG=0 + +if [[ $BATCH -eq 0 ]]; then + DATETIME=`date +'%y-%m-%d-%H-%M-%S'` + MODEL_NAME="mcore-qwen20-72b-internvit-sft-${DATETIME}" +else + MODEL_NAME="mcore-qwen20-72b-internvit-sft" +fi + +WORKSPACE="" +SOURCE=`pwd` +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" + +FINETUNE_DIR="${OUTPUT}/checkpoints" +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" + +# From pretraining. The pretraining checkpoint must be manually split to 4 pipeline parallel stages. +# Please refer to README.md and run examples/multimodal/nvlm/pp_checkpoint_converter.py. +LOAD_NAME="mcore-qwen20-72b-internvit-pp4" + +CHECKPOINT_DIR="${WORKSPACE}/output/${LOAD_NAME}/checkpoints" + +DATA_TRAIN="${SOURCE}/examples/multimodal/nvlm/sft_blend.yaml" + +if [[ $DEBUG -eq 1 ]]; then + MBZ=1 + BZ=1 + NW=0 + AD=0.0 + HD=0.0 + LI=1 + # This is just for interactive testing purposes. Do not use for proper training. + EXTRA_ARGS="--freeze-LM" + ALLOW_NONDETERMINISTIC=1 +else + MBZ=1 + BZ=256 + NW=8 + AD=0.0 + HD=0.0 + LI=5 + EXTRA_ARGS="" + ALLOW_NONDETERMINISTIC=1 +fi + +SEQ_LEN=261 # Image embeddings sequence length (256 image embeddings + 5 tile tag embeddings). +DECODER_SEQ_LEN=3200 # Language model sequence length. 
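For SFT the learning rate drops to 2e-6 with a 2.5e-7 floor, and both warmup and decay are expressed in samples rather than steps (--lr-warmup-samples 83200, --lr-decay-samples 25600000, cosine decay), as set in the options below. A simplified sketch of such a sample-driven warmup-plus-cosine schedule (Megatron's actual scheduler has more options and may handle the decay window slightly differently):

```python
import math


def lr_at(samples_seen: int, max_lr: float = 2e-6, min_lr: float = 2.5e-7,
          warmup_samples: int = 83_200, decay_samples: int = 25_600_000) -> float:
    """Simplified warmup + cosine decay schedule driven by samples, not steps."""
    if samples_seen < warmup_samples:
        return max_lr * samples_seen / warmup_samples  # linear warmup
    progress = min(1.0, (samples_seen - warmup_samples) / (decay_samples - warmup_samples))
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


print(lr_at(0))            # 0.0 at the very first sample
print(lr_at(83_200))       # max_lr once warmup completes
print(lr_at(25_600_000))   # min_lr at the end of the decay window
```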
+MAX_POS_EMBED=8192 + +OPTIONS=" \ + --use-checkpoint-args \ + --exit-duration-in-mins 230 \ + --disable-bias-linear \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model Qwen/Qwen2-72B-Instruct \ + --tokenizer-prompt-format qwen2p0 \ + --transformer-impl transformer_engine \ + --normalization RMSNorm \ + --norm-epsilon 1e-06 \ + --group-query-attention \ + --num-query-groups 8 \ + --no-masked-softmax-fusion \ + --attention-softmax-in-fp32 \ + --attention-dropout ${AD} \ + --hidden-dropout ${HD} \ + --untie-embeddings-and-output-weights \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --tensor-model-parallel-size 8 \ + --pipeline-model-parallel-size 4 \ + --num-layers 80 \ + --hidden-size 8192 \ + --ffn-hidden-size 29568 \ + --add-qkv-bias \ + --num-attention-heads 64 \ + --use-distributed-optimizer \ + --use-te \ + --num-workers ${NW} \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --max-position-embeddings 32768 \ + --train-samples 122880000 \ + --lr-decay-samples 25600000 \ + --lr-warmup-samples 83200 \ + --micro-batch-size ${MBZ} \ + --global-batch-size ${BZ} \ + --lr 2e-6 \ + --min-lr 2.5e-7 \ + --lr-decay-style cosine \ + --log-interval ${LI} \ + --eval-iters 10 \ + --eval-interval 500 \ + --data-path ${DATA_TRAIN} \ + --prompt-path ${SOURCE}/examples/multimodal/nvlm/nvlm_prompts.json \ + --save-interval 10000 \ + --save ${FINETUNE_DIR} \ + --load ${FINETUNE_DIR} \ + --dataloader-save ${FINETUNE_DIR}/dataloader \ + --pretrained-checkpoint ${CHECKPOINT_DIR} \ + --split 100,0,0 \ + --clip-grad 10.0 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.014 \ + --bf16 \ + --eod-mask-loss \ + --freeze-ViT \ + --patch-dim 14 \ + --img-h 448 \ + --img-w 448 \ + --dataloader-type external \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + --language-model-type qwen2.0_72B \ + ${EXTRA_ARGS} \ + --vision-model-type internvit \ + --disable-vision-class-token \ + --log-params-norm \ + --log-num-zeros-in-grad \ + --ckpt-format torch \ + --pixel-shuffle \ + --use-tiling \ + --max-num-tiles 6 \ + --use-thumbnail \ + --use-tile-tags \ + --image-tag-type nvlm +" + + +export NVTE_APPLY_QK_LAYER_SCALING=0 +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${ALLOW_NONDETERMINISTIC} + +# Interactive or batch mode +if [[ $BATCH -eq 0 ]]; then + torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} +else + run_cmd="python -u ${SOURCE}/examples/multimodal/train.py ${OPTIONS}" + + DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` + + srun -l --verbose \ + --container-image \ + --container-mounts "" \ + --output=${LOGS_DIR}/%x_%j_$DATETIME.log \ + sh -c "${run_cmd}" + + set +x +fi diff --git a/examples/multimodal/nvlm/sft_qwen2p5_7b_internvit_6b_video.sh b/examples/multimodal/nvlm/sft_qwen2p5_7b_internvit_6b_video.sh new file mode 100755 index 0000000000..7c88a4e1fa --- /dev/null +++ b/examples/multimodal/nvlm/sft_qwen2p5_7b_internvit_6b_video.sh @@ -0,0 +1,184 @@ +#!/bin/bash + +# Your SBATCH commands here if using SLURM. + +# Please launch this script from megatron-lm root. + +# Train a multimodal model. + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_ALGO=^NVLS +export TOKENIZERS_PARALLELISM=false + +USER=$SLURM_JOB_USER + +# Auto-detect batch or interactive mode. 
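The auto-detection that follows relies on the exit status of `which srun`: status 0 (srun found) gives BATCH=1 and the srun/batch path, anything else gives BATCH=0 and the interactive torchrun path. A minimal Python equivalent of that check, for illustration only:

import shutil

# Mirrors `which srun; BATCH=$((1-$?))` below: batch mode iff srun is on PATH.
batch = shutil.which("srun") is not None
print("batch (srun)" if batch else "interactive (torchrun --nproc_per_node 8)")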
+which srun +BATCH=$((1-$?)) + +DEBUG=0 + +if [[ $BATCH -eq 0 ]]; then + DATETIME=`date +'%y-%m-%d-%H-%M-%S'` + MODEL_NAME="qwen2.5-7B-internvit-video-sft-nvlm-${DATETIME}" +else + MODEL_NAME="qwen2.5-7B-internvitp-video-sft-nvlm" + DEBUG=0 +fi + +WORKSPACE="" +SOURCE=`pwd` +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" + +FINETUNE_DIR="${OUTPUT}/checkpoints" +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" + +# From pretraining. The pretraining checkpoint should have tensor parallel size to 4. +LOAD_NAME="mcore-qwen2p5-7b-internvit-tp4" + +CHECKPOINT_DIR="${WORKSPACE}/output/${LOAD_NAME}/checkpoints" + +DATA_TRAIN="${SOURCE}/examples/multimodal/nvlm/sft_blend.yaml" + +if [[ $DEBUG -eq 1 ]]; then + MBZ=1 + BZ=1 + NW=0 + AD=0.0 + HD=0.0 + LI=1 + # This is just for interactive testing purposes. Do not use for proper training. + EXTRA_ARGS="--freeze-LM" + ALLOW_NONDETERMINISTIC=1 +else + MBZ=1 + BZ=256 + NW=8 + AD=0.0 + HD=0.0 + LI=5 + EXTRA_ARGS="" + ALLOW_NONDETERMINISTIC=1 +fi + +USE_TILING=1 +SEQ_LEN=1024 +DECODER_SEQ_LEN=16384 +MAX_POS_EMBED=32768 +TRAIN_SAMPLES=6602173 +WARMUP_SAMPLES=198065 + + +if [[ $BATCH -eq 0 ]]; then + # Runs out of GPU memory in interactive memory without this. + EXTRA_ARGS+="--freeze-LM" +fi + +if [[ $USE_TILING -eq 1 ]]; then + EXTRA_ARGS+=" --pixel-shuffle --use-tiling --max-num-tiles 12 --use-thumbnail" + SEQ_LEN=256 +fi + + +OPTIONS=" \ + --swiglu \ + --use-distributed-optimizer \ + --num-workers ${NW} \ + --num-layers 28 \ + --hidden-size 3584 \ + --norm-epsilon 1e-06 \ + --normalization RMSNorm \ + --num-attention-heads 28 \ + --exit-duration-in-mins 110 \ + --group-query-attention \ + --num-query-groups 4 \ + --ffn-hidden-size 18944 \ + --add-qkv-bias \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --max-position-embeddings ${MAX_POS_EMBED} \ + --dataloader-seq-length ${DECODER_SEQ_LEN} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model Qwen/Qwen2.5-7B-Instruct \ + --tokenizer-prompt-format qwen2p5 \ + --pixel-shuffle \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --disable-bias-linear \ + --pipeline-model-parallel-size 1 \ + --tensor-model-parallel-size 4 \ + --language-model-type qwen2.5_7B \ + --vision-model-type internvit \ + --micro-batch-size ${MBZ} \ + --global-batch-size ${BZ} \ + --lr 2e-6 \ + --min-lr 2.5e-7 \ + --train-samples ${TRAIN_SAMPLES} \ + --lr-warmup-samples ${WARMUP_SAMPLES} \ + --lr-decay-style cosine \ + --clip-grad 10 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.014 \ + --attention-dropout ${AD} \ + --hidden-dropout ${HD} \ + --eod-mask-loss \ + --bf16 \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + --img-h 448 \ + --img-w 448 \ + --patch-dim 14 \ + --data-path ${DATA_TRAIN} \ + --dataloader-type external \ + --split 100,0,0 \ + --prompt-path ${SOURCE}/examples/multimodal/nvlm/nvlm_prompts.json \ + --log-interval ${LI} \ + --save-interval 500 \ + --eval-interval 500 \ + --eval-iters 10 \ + --log-params-norm \ + --log-num-zeros-in-grad \ + ${EXTRA_ARGS} \ + --save ${FINETUNE_DIR} \ + --load ${FINETUNE_DIR} \ + --pretrained-checkpoint ${CHECKPOINT_DIR} \ + --distributed-timeout-minutes 60 \ + --allow-missing-vision-projection-checkpoint \ + --dataloader-save ${FINETUNE_DIR}/dataloader \ + --disable-vision-class-token \ + --use-te \ + --ckpt-format torch \ + --num-frames 32 \ + --use-checkpoint-args \ + --image-tag-type internvl \ + --recompute-granularity full \ + 
--recompute-method block \ + --recompute-num-layers 28 \ + --recompute-vision \ +" + + +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${ALLOW_NONDETERMINISTIC} +export NVTE_APPLY_QK_LAYER_SCALING=0 + +# Interactive or batch mode +if [[ $BATCH -eq 0 ]]; then + torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} +else + run_cmd="python -u ${SOURCE}/examples/multimodal/train.py ${OPTIONS}" + + DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` + + srun -l --verbose \ + --container-image \ + --container-mounts "" \ + --output=${LOGS_DIR}/%x_%j_$DATETIME.log \ + sh -c "${run_cmd}" + + set +x +fi diff --git a/examples/multimodal/pretrain_dataset.yaml b/examples/multimodal/pretrain_dataset.yaml new file mode 100644 index 0000000000..f27bccba30 --- /dev/null +++ b/examples/multimodal/pretrain_dataset.yaml @@ -0,0 +1,15 @@ +__module__: megatron.energon +__class__: Metadataset +splits: + train: + datasets: + - weight: 1. + path: + subflavors: + augmentation: false + val: + datasets: + - weight: 1. + path: + subflavors: + augmentation: false diff --git a/examples/multimodal/pretrain_mistral_clip.sh b/examples/multimodal/pretrain_mistral_clip.sh new file mode 100755 index 0000000000..90b0053d19 --- /dev/null +++ b/examples/multimodal/pretrain_mistral_clip.sh @@ -0,0 +1,128 @@ +#!/bin/bash +# Pretrain a multimodal model. + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +MODEL_NAME="mcore-llava-mistral-7b-instruct-clip336-pretraining" + +# Check that the user has set an output path for model checkpoints. +if [[ -z $WORKSPACE ]]; then + echo "Please set WORKSPACE for storing your model checkpoints." + exit 1 +fi + +SOURCE=`pwd` +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" + +FINETUNE_DIR=${OUTPUT}/checkpoints +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" + +if [[ -z $LOAD_NAME ]]; then + echo "Please set LOAD_NAME for input model name." 
+ exit 1 +fi + +CHECKPOINT_DIR="${WORKSPACE}/${LOAD_NAME}/checkpoints" + +DATA_TRAIN="${SOURCE}/examples/multimodal/pretrain_dataset.yaml" + +DEBUG=0 +if [[ $DEBUG -eq 1 ]]; then + BZ=32 + NW=2 + HD=0.0 + LI=1 + EXTRA_ARGS="" + NONDETERMINISTIC_ATTN=1 +else + BZ=256 + NW=2 + HD=0.1 + LI=10 + EXTRA_ARGS="" + NONDETERMINISTIC_ATTN=1 +fi + +OPTIONS=" \ + --apply-layernorm-1p \ + --attention-softmax-in-fp32 \ + --use-checkpoint-args \ + --use-distributed-optimizer \ + --transformer-impl transformer_engine \ + --use-te \ + --normalization RMSNorm \ + --group-query-attention \ + --num-query-groups 8 \ + --no-masked-softmax-fusion \ + --num-workers ${NW} \ + --exit-duration-in-mins 230 \ + --use-flash-attn \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --attention-dropout 0.0 \ + --hidden-dropout ${HD} \ + --tensor-model-parallel-size 4 \ + --pipeline-model-parallel-size 1 \ + --num-layers 32 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --seq-length 576 \ + --decoder-seq-length 1024 \ + --max-position-embeddings 4096 \ + --ffn-hidden-size 14336 \ + --train-iters 20000 \ + --micro-batch-size 1 \ + --global-batch-size ${BZ} \ + --lr-decay-iters 20000 \ + --lr-warmup-fraction .01 \ + --lr 0.00015 \ + --min-lr 1.0e-5 \ + --lr-decay-style cosine \ + --log-interval ${LI} \ + --eval-iters 10 \ + --eval-interval 1000 \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model mistralai/Mistral-7B-Instruct-v0.3 \ + --tokenizer-prompt-format mistral \ + --data-path ${DATA_TRAIN} \ + --prompt-path ${SOURCE}/examples/multimodal/manual_prompts.json \ + --save-interval 1000 \ + --save ${FINETUNE_DIR} \ + --load ${FINETUNE_DIR} \ + --dataloader-save ${FINETUNE_DIR}/dataloader \ + --pretrained-checkpoint ${CHECKPOINT_DIR} \ + --split 100,0,0 \ + --clip-grad 1.0 \ + --weight-decay 1e-2 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.014 \ + --log-params-norm \ + --log-num-zeros-in-grad \ + --bf16 \ + --eod-mask-loss \ + --freeze-LM \ + --freeze-ViT \ + --patch-dim 14 \ + --img-h 336 \ + --img-w 336 \ + --dataloader-type external \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + --language-model-type=mistral_7b \ + --disable-vision-class-token \ + ${EXTRA_ARGS} \ + --distributed-timeout-minutes 60 \ + --allow-missing-vision-projection-checkpoint \ + --ckpt-format torch +" + +export NVTE_APPLY_QK_LAYER_SCALING=0 +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${NONDETERMINISTIC_ATTN} + +torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} \ No newline at end of file diff --git a/examples/multimodal/run_text_generation.py b/examples/multimodal/run_text_generation.py new file mode 100644 index 0000000000..c9200b77f7 --- /dev/null +++ b/examples/multimodal/run_text_generation.py @@ -0,0 +1,589 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +"""Generate text using a vision language model.""" +import json +import logging +import os +import sys +from functools import partial + +# Add megatron to the path. 
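Note on the pretraining recipe above (pretrain_mistral_clip.sh): it passes --freeze-LM and --freeze-ViT, so only the vision projection receives gradient updates, and --allow-missing-vision-projection-checkpoint lets that newly added projection start from scratch rather than requiring it in the pretrained checkpoint. Conceptually, freezing amounts to disabling gradients on those submodules; a minimal PyTorch sketch of the idea (module names are illustrative, not Megatron's actual LLaVA structure):

import torch.nn as nn

def freeze(module: nn.Module) -> None:
    # Disable gradient updates for every parameter in the module.
    for p in module.parameters():
        p.requires_grad = False

class ToyVLM(nn.Module):
    # Stand-in modules only, for illustration.
    def __init__(self):
        super().__init__()
        self.vision_model = nn.Linear(1024, 1024)
        self.vision_projection = nn.Linear(1024, 4096)
        self.language_model = nn.Linear(4096, 4096)

model = ToyVLM()
freeze(model.vision_model)    # what --freeze-ViT amounts to
freeze(model.language_model)  # what --freeze-LM amounts to
print([n for n, p in model.named_parameters() if p.requires_grad])  # only vision_projection.* remain trainable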
+sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) +) + +import torch +import yaml +from config import EvaluationConfig +from evaluation.evaluation_datasets import get_evaluation_dataset +from model import model_provider +from multimodal_args import add_multimodal_extra_args + +from megatron.core import parallel_state +from megatron.core.enums import ModelType +from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN +from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings +from megatron.inference.text_generation.api import generate_and_post_process +from megatron.inference.text_generation.forward_step import ForwardStep +from megatron.inference.text_generation.communication import broadcast_int_list +from megatron.core.inference.sampling_params import SamplingParams +from megatron.core.inference.engines.mcore_engine import MCoreEngine +from megatron.core.inference.inference_request import InferenceRequest, VLMInferenceRequest +from megatron.core.inference.text_generation_controllers.vlm_text_generation_controller import ( + VLMTextGenerationController, +) +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, +) +from megatron.core.inference.model_inference_wrappers.multimodal.vlm_inference_wrapper import ( + VLMInferenceWrapper, +) +from megatron.training import get_args, get_model, get_tokenizer, print_rank_0 +from megatron.training.checkpointing import load_checkpoint +from megatron.training.initialize import initialize_megatron + + +def add_text_generation_args(parser): + """Text generation arguments.""" + group = parser.add_argument_group(title='Vision language model text generation arguments') + + group.add_argument("--temperature", type=float, default=1.0, help='Sampling temperature.') + group.add_argument("--top_p", type=float, default=0.0, help='Top p sampling.') + group.add_argument("--top_k", type=int, default=0, help='Top k sampling.') + group.add_argument( + "--out-seq-length", type=int, default=128, help='Length of the output generated text.' + ) + group.add_argument("--output-path", type=str, help='Output file path') + group.add_argument('--input-image-path', type=str, help="Input image directory") + group.add_argument( + '--num-partitions', type=int, default=0, help="Number of partitions for inputs." + ) + group.add_argument('--partition-id', type=int, default=0, help="Partition index") + group.add_argument("--gt-path", type=str, help="Optional ground truth file") + group.add_argument( + "--task", + type=str, + choices=[ + "captioning", + "TextVQA", + "VQAv2", + "ChartQA", + "MMMU", + "VideoMME", + "OCRBench", + "MathVista", + "AI2D", + ], + help="Generation task to run", + ) + group.add_argument( + "--num-samples-per-partition", type=int, default=0, help="Number of samples per partition" + ) + group.add_argument("--config-path", type=str, help="Evaluation config file to use.") + + group.add_argument("--use-mcore-inference", action="store_true", default=False, help="Use the MCore inference API") + + # Add common multimodal arguments needed for e.g. building the model. 
+ parser = add_multimodal_extra_args(parser) + + return parser + + +def get_evaluation_dataloader( + task, + input_image_path, + gt_path, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + num_samples_per_partition, + num_partitions, + partition_id, + num_frames, + num_workers, + vision_model_type, +): + """Build evaluation dataset.""" + dataset = get_evaluation_dataset( + task, + input_image_path, + gt_path, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + num_samples_per_partition, + num_partitions, + partition_id, + num_frames, + vision_model_type, + ) + + dp_rank = parallel_state.get_data_parallel_rank() + dp_world_size = parallel_state.get_data_parallel_world_size() + + sampler = torch.utils.data.DistributedSampler( + dataset, shuffle=False, num_replicas=dp_world_size, rank=dp_rank + ) + # TODO: Batched inference is not supported yet. + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=None, num_workers=num_workers, sampler=sampler, pin_memory=True + ) + + return dataloader + + +def generate_samples(model, config: EvaluationConfig, print_output): + """Text generation using a trained vision language model.""" + args = get_args() + + dataloader = get_evaluation_dataloader( + config.task, + config.input_image_path, + config.gt_path, + args.img_h, + args.img_w, + args.use_tiling, + args.max_num_tiles, + args.use_thumbnail, + config.num_samples_per_partition, + config.num_partitions, + config.partition_id, + args.num_frames, + args.num_workers, + args.vision_model_type, + ) + + num_img_embeddings_per_tile = get_num_image_embeddings( + args.img_h, + args.img_w, + args.patch_dim, + args.vision_model_type, + args.disable_vision_class_token, + 1, + args.pixel_shuffle, + args.use_tile_tags, + ) + + if args.use_mcore_inference: + inference_wrapper_config = InferenceWrapperConfig( + hidden_size=args.hidden_size, + inference_batch_times_seqlen_threshold=args.inference_batch_times_seqlen_threshold, + fp32_residual_connection=args.fp32_residual_connection, + params_dtype=args.params_dtype, + padded_vocab_size=args.padded_vocab_size, + ) + inference_wrapped_model = VLMInferenceWrapper(model, inference_wrapper_config) + tokenizer = get_tokenizer() + controller = VLMTextGenerationController( + inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer + ) + inference_engine = MCoreEngine( + controller, max_batch_size=1, random_seed=args.seed + ) + sampling_params = SamplingParams( + temperature=config.temperature, + top_k=config.top_k, + top_p=config.top_p, + num_tokens_to_generate=config.out_seq_length, + ) + + for idx, (imgs, num_tiles, sample_id, question, answers, metadata) in enumerate(dataloader): + imgs = imgs.to("cuda") + num_tiles = num_tiles.to("cuda") + + conv = get_conversation(config.task, question) + + if not args.use_mcore_inference: + forward_step = partial(VLMForwardStep, num_img_embeddings_per_tile, imgs, num_tiles, args.decoder_seq_length) + + + if is_first_rank(): + + if args.use_mcore_inference: + inference_request = VLMInferenceRequest( + request_id=inference_engine.get_new_request_id(), + prompt=conv, + prompt_tokens=controller.tokenize_prompt(conv), + inference_parameters=sampling_params, + num_img_embeddings_per_tile=num_img_embeddings_per_tile, + imgs=imgs, + num_tiles=num_tiles, + decoder_seq_length=args.decoder_seq_length, + ) + results: List[InferenceRequest] = inference_engine.generate( + inference_requests=[inference_request] + ) + + resp_sentences = [ + tokenizer.detokenize(result.prompt_tokens) + 
result.generated_text + for result in results + ] + else: + resp_sentences, _, _, _ = generate_and_post_process( + model, + forward_step=forward_step, + prompts=[conv], + tokens_to_generate=config.out_seq_length, + top_k_sampling=config.top_k, + top_p_sampling=config.top_p, + add_BOS=False, + temperature=config.temperature, + random_seed=args.seed, + detokenize_segments=False, + data_parallel=True, + ) + + for generation in resp_sentences: + if isinstance(sample_id, torch.Tensor): + sample_id = sample_id.item() + + output = {"sample_id": sample_id} + + output_name = "" + if config.task == "captioning": + output_name = "caption" + elif config.task in ( + "TextVQA", + "VQAv2", + "ChartQA", + "OCRBench", + "MathVista", + "AI2D", + ): + output_name = "answer" + elif config.task in ("MMMU"): + output_name = "text" + elif config.task == "VideoMME": + output_name = "response" + output = question + else: + raise NotImplementedError("no output name defined for", config.task) + + prompt, generated = get_prompt_and_generated( + generation, args.tokenizer_prompt_format + ) + if config.task == "VideoMME": + output["questions"][0][output_name] = generated + else: + output["prompt"] = prompt + output[output_name] = generated + + if config.task == "captioning": + output["ground_truth"] = answers + elif config.task in ( + "TextVQA", + "VQAv2", + "ChartQA", + "OCRBench", + "MathVista", + "AI2D", + ): + if isinstance(answers, str): + answers = [answers] + output["gt_answer"] = answers + + if len(metadata) > 0: + output.update(metadata) + elif config.task == "MMMU": + output["prediction"] = generated + output.update(metadata) + else: + raise NotImplementedError("no output processing defined for", config.task) + + if print_output: + print(output) + + yield output + idx += 1 + else: + if args.use_mcore_inference: + inference_request = VLMInferenceRequest( + request_id=inference_engine.get_new_request_id(), + prompt=conv, + prompt_tokens=controller.tokenize_prompt(conv), + inference_parameters=sampling_params, + num_img_embeddings_per_tile=num_img_embeddings_per_tile, + imgs=imgs, + num_tiles=num_tiles, + decoder_seq_length=args.decoder_seq_length, + ) + inference_engine.generate( + inference_requests=[inference_request] + ) + else: + generate_and_post_process( + model, forward_step=forward_step, detokenize_segments=False, data_parallel=True + ) + + idx += 1 + + +def get_evaluation_config(): + """Get evaluation config from a config file or command-line arguments.""" + args = get_args() + if args.config_path: + with open(args.config_path, "r") as f: + config_dict = yaml.safe_load(f) + + config = EvaluationConfig(**config_dict) + else: + config = EvaluationConfig( + task=args.task, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + out_seq_length=args.out_seq_length, + output_path=args.output_path, + input_image_path=args.input_image_path, + gt_path=args.gt_path, + num_partitions=args.num_partitions, + partition_id=args.partition_id, + num_samples_per_partition=args.num_samples_per_partition, + ) + + # Default output path if not defined... 
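For reference, a file passed via --config-path only needs to populate the EvaluationConfig fields used above; a hypothetical example expressed as the dict the YAML loads into (all values are placeholders, and EvaluationConfig comes from the local examples/multimodal config module):

# Hypothetical contents of a --config-path YAML, shown as the loaded dict.
config_dict = {
    "task": "TextVQA",               # one of the choices accepted by --task
    "temperature": 1.0,
    "top_k": 1,
    "top_p": 0.0,
    "out_seq_length": 128,
    "input_image_path": "/data/textvqa/images",          # placeholder path
    "gt_path": "/data/textvqa/val_annotations.json",     # placeholder path
    "output_path": "generated/mistral_7b",
    "num_partitions": 0,
    "partition_id": 0,
    "num_samples_per_partition": 0,
}
# config = EvaluationConfig(**config_dict)  # as done above when --config-path is set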
+ if not config.output_path: + os.makedirs("generated", exist_ok=True) + config.output_path = "generated/" + args.language_model_type + + return config + + +def is_first_rank(): + """First tensor and pipeline parallel rank.""" + return ( + parallel_state.is_pipeline_first_stage(ignore_virtual=True) + and parallel_state.get_tensor_model_parallel_rank() == 0 + ) + + +def get_output_path(config, dp_rank): + """Generation output path.""" + return ( + f"{config.output_path}-{config.task}-dprank={dp_rank}-partition={config.partition_id}.jsonl" + ) + + +def generate_and_write_samples(model, config, print_output=True): + """Generate text and write to an output file.""" + dp_rank = parallel_state.get_data_parallel_rank() + + if is_first_rank(): + output_path = get_output_path(config, dp_rank) + output_file = open(output_path, "w") + print(f"output path: {output_file.name}") + + with torch.no_grad(): + for output in generate_samples(model, config, print_output): + if is_first_rank(): + output_file.write(json.dumps(output) + "\n") + output_file.flush() + + if is_first_rank(): + output_file.close() + +class VLMForwardStep(ForwardStep): + """Inference forward step for a multimodal model.""" + + def __init__( + self, + num_img_embeddings_per_tile, + images, + num_tiles, + decoder_seq_length, + model, + max_batch_size, + max_sequence_length, + ): + """Create multimodal forward step.""" + total_num_tiles = torch.sum(num_tiles).item() + num_img_embeddings = num_img_embeddings_per_tile * total_num_tiles + + super().__init__(model, max_batch_size, max_sequence_length + num_img_embeddings) + self._images = images + self._num_tiles = num_tiles + self._num_img_embeddings = num_img_embeddings + self.decoder_seq_length = decoder_seq_length + + self._recv_only_vision_embeds = False + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + # Checks if the previous stage only has a vision encoder, and that the current stage has part of the LM decoder. + # In this case, the current stage should only receive vision embeddings. + if pp_rank > 0: + self._recv_only_vision_embeds = parallel_state.is_inside_encoder(pp_rank - 1) and (not parallel_state.is_inside_decoder(pp_rank - 1)) and parallel_state.is_inside_decoder() + + # Checks if the current stage only has a vision encoder + self._encoder_only = parallel_state.is_inside_encoder() and not parallel_state.is_inside_decoder() + + def _forward(self, tokens, position_ids, attention_mask): + return self.model( + self._images, + tokens, + position_ids, + attention_mask=None, + inference_params=self.inference_params, + num_image_tiles=self._num_tiles, + runtime_gather_output=True, + ) + + def __call__(self, tokens, position_ids, attention_mask): + num_image_tokens = (tokens == self.model.module.image_token_index).sum().item() + num_tokens = tokens.size(1) + recv_buffer_seq_length = None + if num_image_tokens > 0: + # When there are image tokens and this stage only receives vision embeddings, adjust the recv buffer seq length to match the image embeddings sequence length. + # If there are image tokens and this stage receives full embeddings, make sure we compensate for expansion of image tokens. + # Note that this will set a recv_buffer_seq_length for the encoder stage, this length is irrelevant since that recv buffer is never allocated. 
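A worked example of the receive-buffer sizing implemented just below, with hypothetical numbers: four image tiles at 256 embeddings per tile and a 64-token prompt containing one image token:

# Hypothetical numbers; mirrors the branch below for a stage that receives full embeddings.
num_img_embeddings_per_tile = 256
total_num_tiles = 4
num_img_embeddings = num_img_embeddings_per_tile * total_num_tiles  # 1024
num_tokens = 64            # prompt length in tokens
num_image_tokens = 1       # one image placeholder that expands to 1024 embeddings
decoder_seq_length = 3200
recv_buffer_seq_length = min(num_img_embeddings + num_tokens - num_image_tokens, decoder_seq_length)
print(recv_buffer_seq_length)  # 1087; a vision-embeddings-only stage would use 1024 instead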
+ if self._recv_only_vision_embeds: + recv_buffer_seq_length = self._num_img_embeddings + else: + recv_buffer_seq_length = min(self._num_img_embeddings + num_tokens - num_image_tokens, self.decoder_seq_length) + elif self._recv_only_vision_embeds: + # If this stage only receives vision embeddings and there are no image tokens we won't run the encoder and therefore shouldn't try to recv. + recv_buffer_seq_length = 0 + + # If the pipeline stage only has a vision encoder, then it only needs to run when there are image tokens + if not (self._encoder_only and num_image_tokens == 0): + output = super().__call__(tokens, position_ids, attention_mask, recv_buffer_seq_length=recv_buffer_seq_length) + else: + output = None + if isinstance(output, tuple): + logits, _ = output + else: + logits = output + + # On the first inference iteration, we compute image tokens. + # On every PP stage(although inference params should only matter for decoder), + # update the sequence length offset by the number of image tokens. + if num_tokens > 1 and num_image_tokens > 0: + if "image_tokens_count" not in self.inference_params.key_value_memory_dict: + self.inference_params.key_value_memory_dict["image_tokens_count"] = self._num_img_embeddings + + if self._num_img_embeddings + num_tokens - num_image_tokens > self.decoder_seq_length: + self.inference_params.sequence_len_offset += self.decoder_seq_length - num_tokens + else: + self.inference_params.sequence_len_offset += ( + self.inference_params.key_value_memory_dict["image_tokens_count"] - num_image_tokens + ) + + return logits + + +def get_conversation(task, question): + """Get a conversation for a given task and evaluation question.""" + conversation = [] + + # In all cases, the tokenizer adds possible header tokens for the assistant. + if task == "captioning": + conversation = [ + {"role": "system", "content": "Answer the questions."}, + { + "role": "user", + "content": f"{IMAGE_TOKEN}\nProvide a one-sentence caption for provided image.", + }, + ] + elif task in ("TextVQA", "VQAv2", "ChartQA"): + conversation = [ + {"role": "system", "content": "Answer the questions."}, + { + "role": "user", + "content": f"{IMAGE_TOKEN}\n{question}\nAnswer the question using a single word or phrase.", + }, + ] + elif task in ("OCRBench", "MathVista", "AI2D"): + conversation = [ + {"role": "system", "content": "Answer the questions."}, + {"role": "user", "content": f"{IMAGE_TOKEN}\n{question}"}, + ] + elif task == "MMMU": + conversation = [ + {"role": "system", "content": "Answer the questions."}, + {"role": "user", "content": question}, + ] + elif task == "VideoMME": + q = ( + "Select the best answer to the following multiple-choice " + "question based on the video. 
Respond with only the letter " + "(A, B, C, or D) of the correct option.\n" + ) + q += question["questions"][0]["question"] + "\n" + q += question["questions"][0]["choices"][0] + "\n" + q += question["questions"][0]["choices"][1] + "\n" + q += question["questions"][0]["choices"][2] + "\n" + q += question["questions"][0]["choices"][3] + "\n" + + conversation = [ + {"role": "system", "content": "Answer the questions."}, + {"role": "user", "content": f"{IMAGE_TOKEN}\n{q}"}, + ] + + return conversation + + +def get_prompt_and_generated(prompt_and_generation, prompt_format): + """Strip prompt and other unnecessary text from generation.""" + if prompt_format in ("llama3", "llama3p1"): + splitted = prompt_and_generation.split("<|start_header_id|>assistant<|end_header_id|>\n\n") + prompt = splitted[0] + generated = splitted[1] + generated = generated.split("<|eot_id|>")[0] + elif prompt_format == "mistral": + splitted = prompt_and_generation.split("[/INST]") + prompt = splitted[0] + generated = splitted[1] + generated = generated.split("")[0] + elif prompt_format == "chatml": + splitted = prompt_and_generation.split("<|im_start|> assistant\n") + prompt = splitted[0] + generated = splitted[1] + generated = generated.split("<|im_end|>")[0] + elif prompt_format in ("nvlm-yi-34b", "qwen2p0", "qwen2p5"): + splitted = prompt_and_generation.split("<|im_start|>assistant\n") + prompt = splitted[0] + generated = splitted[1] + generated = generated.split("<|im_end|>")[0] + else: + raise ValueError(f"Prompt format {prompt_format} is not supported.") + + # Remove possible garbage. + generated = generated.strip() + generated = generated.split("\n\n")[0] + generated = generated.split("\n")[0] + + return prompt, generated + + +def main(): + """Vision language model text generation.""" + initialize_megatron(extra_args_provider=add_text_generation_args) + + if torch.distributed.get_rank() == 0: + logging.getLogger(__name__).warning( + "Models using pipeline parallelism are not supported yet." + ) + + args = get_args() + + def wrapped_model_provider(pre_process, post_process, add_encoder, add_decoder): + return model_provider(pre_process, post_process, add_encoder, add_decoder, parallel_output=False) + + # Set up model and load checkpoint. + model = get_model(wrapped_model_provider, model_type=ModelType.encoder_and_decoder, wrap_with_ddp=False) + + if args.load is not None: + _ = load_checkpoint(model, None, None) + + model = model[0] + + model.eval() + + config = get_evaluation_config() + + generate_and_write_samples(model, config) + + +if __name__ == "__main__": + main() diff --git a/examples/multimodal/sft_dataset.yaml b/examples/multimodal/sft_dataset.yaml new file mode 100644 index 0000000000..c9f0257ae7 --- /dev/null +++ b/examples/multimodal/sft_dataset.yaml @@ -0,0 +1,15 @@ +__module__: megatron.energon +__class__: Metadataset +splits: + train: + datasets: + - weight: 1. + path: + subflavors: + augmentation: false + val: + datasets: + - weight: 1. + path: + subflavors: + augmentation: false diff --git a/examples/multimodal/sft_mistral_clip.sh b/examples/multimodal/sft_mistral_clip.sh new file mode 100755 index 0000000000..94ff208eb4 --- /dev/null +++ b/examples/multimodal/sft_mistral_clip.sh @@ -0,0 +1,130 @@ +#!/bin/bash +# Run SFT on a pretrained multimodal model + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +MODEL_NAME="mcore-llava-mistral-7b-instruct-clip336-sft" + +# Check that the user has set an output path for model checkpoints. 
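As a usage sketch for the get_prompt_and_generated helper defined earlier in run_text_generation.py, the snippet below traces its qwen2p0/qwen2p5 branch on a hypothetical model output (the string is illustrative only):

# Mirrors the qwen2p0/qwen2p5 branch of get_prompt_and_generated above.
raw = (
    "<|im_start|>system\nAnswer the questions.<|im_end|>\n"
    "<|im_start|>user\n<image>\nWhat animal is shown?<|im_end|>\n"
    "<|im_start|>assistant\nA cat.<|im_end|>"
)
prompt, generated = raw.split("<|im_start|>assistant\n")
generated = generated.split("<|im_end|>")[0]
generated = generated.strip().split("\n\n")[0].split("\n")[0]
print(prompt)     # system + user turns
print(generated)  # "A cat."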
+if [[ -z $WORKSPACE ]]; then + echo "Please set WORKSPACE for storing your model checkpoints." + exit 1 +fi + +SOURCE=`pwd` +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" + +FINETUNE_DIR=${OUTPUT}/checkpoints +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" + +if [[ -z $LOAD_NAME ]]; then + echo "Please set LOAD_NAME for input model name." + exit 1 +fi + +if [[ -z $LOAD_ITER ]]; then + echo "Please set LOAD_ITER for pre-trained input model iteration." + exit 1 +fi + +CHECKPOINT_DIR="${WORKSPACE}/${LOAD_NAME}/checkpoints" + +DATA_TRAIN="${SOURCE}/examples/multimodal/sft_dataset.yaml" + +DEBUG=0 +if [[ $DEBUG -eq 1 ]]; then + BZ=8 + NW=1 + HD=0.0 + LI=1 + EXTRA_ARGS="" + NONDETERMINISTIC_ATTN=1 +else + BZ=128 + NW=2 + HD=0.1 + LI=10 + EXTRA_ARGS="" + NONDETERMINISTIC_ATTN=1 +fi + +OPTIONS=" \ + --apply-layernorm-1p \ + --attention-softmax-in-fp32 \ + --use-checkpoint-args \ + --use-distributed-optimizer \ + --transformer-impl transformer_engine \ + --use-te \ + --normalization RMSNorm \ + --group-query-attention \ + --num-query-groups 8 \ + --no-masked-softmax-fusion \ + --num-workers ${NW} \ + --exit-duration-in-mins 230 \ + --use-flash-attn \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --attention-dropout 0.0 \ + --hidden-dropout ${HD} \ + --tensor-model-parallel-size 4 \ + --pipeline-model-parallel-size 1 \ + --num-layers 32 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --seq-length 576 \ + --decoder-seq-length 2048 \ + --max-position-embeddings 4096 \ + --ffn-hidden-size 14336 \ + --train-iters 20000 \ + --micro-batch-size 1 \ + --global-batch-size ${BZ} \ + --lr-decay-iters 20000 \ + --lr-warmup-fraction .01 \ + --lr 1e-6 \ + --min-lr 1e-7 \ + --lr-decay-style cosine \ + --log-interval ${LI} \ + --eval-iters 10 \ + --eval-interval 500 \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model mistralai/Mistral-7B-Instruct-v0.3 \ + --tokenizer-prompt-format mistral \ + --data-path ${DATA_TRAIN} \ + --prompt-path ${SOURCE}/examples/multimodal/manual_prompts.json \ + --save-interval 500 \ + --save ${FINETUNE_DIR} \ + --load ${FINETUNE_DIR} \ + --pretrained-checkpoint ${CHECKPOINT_DIR} \ + --dataloader-save ${FINETUNE_DIR}/dataloader \ + --split 100,0,0 \ + --clip-grad 0.5 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.014 \ + --log-params-norm \ + --log-num-zeros-in-grad \ + --eod-mask-loss \ + --freeze-ViT \ + --patch-dim 14 \ + --img-h 336 \ + --img-w 336 \ + --dataloader-type external \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + --language-model-type=mistral_7b \ + --disable-vision-class-token \ + ${EXTRA_ARGS} \ + --distributed-timeout-minutes 60 \ + --ckpt-format torch +" + +export NVTE_APPLY_QK_LAYER_SCALING=0 +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${NONDETERMINISTIC_ATTN} + +torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} diff --git a/examples/multimodal/text_generation_mistral_clip.sh b/examples/multimodal/text_generation_mistral_clip.sh new file mode 100755 index 0000000000..c1ef7bcee8 --- /dev/null +++ b/examples/multimodal/text_generation_mistral_clip.sh @@ -0,0 +1,109 @@ +#!/bin/bash + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NVTE_APPLY_QK_LAYER_SCALING=0 + +INPUT_IMAGE_PATH="placeholder" +GROUNDTRUTH_PATH="placeholder" +NUM_FRAMES=1 + +while [[ $# -gt 0 ]]; do + case $1 in + -i|--input-image-path) + 
INPUT_IMAGE_PATH="$2" + shift + shift + ;; + --num-frames) + NUM_FRAMES="$2" + shift + shift + ;; + -o|--output-path) + OUTPUT_PATH="$2" + shift + shift + ;; + -m|--model-path) + MODEL_PATH="$2" + shift + shift + ;; + -t|--task) + TASK="$2" + shift + shift + ;; + -g|--gt-path) + GROUNDTRUTH_PATH="$2" + shift + shift + ;; + -*|--*) + echo "Invalid option $1" + exit 1 + ;; + esac +done + +# Please modify these as needed. +NUM_PARTITIONS=0 +START=0 +END=0 + +for PARTITION_ID in $( eval echo {$START..$END} ) +do + torchrun --nproc_per_node 8 examples/multimodal/run_text_generation.py \ + --apply-layernorm-1p \ + --attention-softmax-in-fp32 \ + --use-flash-attn \ + --transformer-impl transformer_engine \ + --use-te \ + --use-checkpoint-args \ + --normalization RMSNorm \ + --language-model-type mistral_7b \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --tensor-model-parallel-size 4 \ + --pipeline-model-parallel-size 1 \ + --group-query-attention \ + --num-query-groups 8 \ + --num-layers 32 \ + --hidden-size 4096 \ + --ffn-hidden-size 14336 \ + --num-attention-heads 32 \ + --max-position-embeddings 4096 \ + --no-masked-softmax-fusion \ + --load ${MODEL_PATH} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model mistralai/Mistral-7B-Instruct-v0.3 \ + --tokenizer-prompt-format mistral \ + --bf16 \ + --micro-batch-size 1 \ + --seq-length 2048 \ + --out-seq-length 12 \ + --temperature 1.0 \ + --img-h 336 \ + --img-w 336 \ + --patch-dim 14 \ + --seed 153 \ + --top_k 1 \ + --no-load-rng \ + --no-load-optim \ + --input-image-path ${INPUT_IMAGE_PATH} \ + --num-partitions ${NUM_PARTITIONS} \ + --partition-id ${PARTITION_ID} \ + --output-path ${OUTPUT_PATH} \ + --gt-path ${GROUNDTRUTH_PATH} \ + --task ${TASK} \ + --disable-vision-class-token \ + --num-frames ${NUM_FRAMES} \ + --ckpt-format torch +done diff --git a/examples/multimodal/train.py b/examples/multimodal/train.py new file mode 100644 index 0000000000..a81a2f3a26 --- /dev/null +++ b/examples/multimodal/train.py @@ -0,0 +1,416 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +"""Pretrain or SFT multimodal.""" +import math +import os +import sys +from functools import partial + +import torch +import yaml + +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) +) + +from dataloader_provider import train_valid_test_dataloaders_provider, is_first_or_last_stage +from model import model_provider +from multimodal_args import add_multimodal_extra_args + +from megatron.core import mpu, tensor_parallel +from megatron.core.enums import ModelType +from megatron.core.models.multimodal import context_parallel +from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, LLaVAModel +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.parallel_state import ( + get_tensor_model_parallel_rank, + get_pipeline_model_parallel_world_size, + is_pipeline_last_stage, +) +from megatron.training import get_args, get_timers, get_tokenizer, pretrain +from megatron.training.utils import is_last_rank, get_batch_on_this_cp_rank + + +def get_batch(data_iterator, image_token_index, img_seq_len): + """Generate a batch + + Note: attn_mask_type in layer_specs.py sets the attention mask. Attention mask is None here. 
+ """ + imgs = None + tokens = None + labels = None + loss_mask = None + attention_mask = None + position_ids = None + num_tiles = None + packed_seq_params = None + + args = get_args() + + # Dataloader doesn't run on the middle stages in a pipeline parallel model. + pp_size = get_pipeline_model_parallel_world_size() + if not is_first_or_last_stage(pp_size, args.encoder_pipeline_model_parallel_size): + # Note these are all set to None above. + return tokens, labels, loss_mask, attention_mask, position_ids, imgs, num_tiles, packed_seq_params + + # Broadcast data. + torch.cuda.nvtx.range_push("get_data") + if data_iterator is not None and get_tensor_model_parallel_rank() == 0: + data = next(data_iterator) + else: + data = None + + data_text = tensor_parallel.broadcast_data(["tokens"], data, torch.int64)["tokens"] + labels = tensor_parallel.broadcast_data(["labels"], data, torch.int64)["labels"] + + imgs = tensor_parallel.broadcast_data(["imgs"], data, torch.float32)["imgs"] + num_tiles = tensor_parallel.broadcast_data(["num_tiles"], data, torch.int32)["num_tiles"] + + cu_lengths = tensor_parallel.broadcast_data(["cu_lengths"], data, torch.int32)["cu_lengths"] + max_lengths = tensor_parallel.broadcast_data(["max_lengths"], data, torch.int32)["max_lengths"] + + # No image input (text-only sample) if the dataloader returned a size 1 image. + if imgs.shape == torch.Size([1, 1]): + # FSDP can hang with text-only samples. A workaround is to run a valid dummy image through the vision + # model and then add image embeddings with a zero multiplier. + if args.use_torch_fsdp2: + imgs = torch.zeros((1, 3, args.img_h, args.img_w), dtype=torch.float32, device=data_text.device) + num_tiles = torch.tensor([], dtype=torch.int, device=data_text.device) + else: + # Similar workaround is not needed without FSDP and we can use an empty image. + # FIXME: text-only data can cause still cause a hang in the special case where + # the vision model is own its own pipeline rank and --freeze-ViT is enabled. + imgs = torch.tensor([], dtype=torch.float32, device=data_text.device) + num_tiles = torch.tensor([], dtype=torch.int, device=data_text.device) + + # Last pipeline parallel stage doesn't need images. + if pp_size > 1 and is_pipeline_last_stage(): + imgs = None + + # If cu_lengths and max_lengths are non-dummy, construct PackedSeqParams. Otherwise, leave it at None. + if cu_lengths.shape != torch.Size([1, 1]): + assert ( + cu_lengths.shape[0] == max_lengths.shape[0] == 1 + ), "micro-batch-size must be 1 for packing" + cu_lengths = cu_lengths[0] + max_lengths = max_lengths[0] + + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_lengths, + cu_seqlens_kv=cu_lengths, + max_seqlen_q=max_lengths, + max_seqlen_kv=max_lengths, + ) + + torch.cuda.nvtx.range_pop() + + tokens_ = data_text.long() + + torch.cuda.nvtx.range_push("index tokens") + tokenizer = get_tokenizer() + text_length = tokens_.shape[1] + tokens = tokens_[:, :text_length].contiguous() + labels = labels[:, 1 : text_length + 1].contiguous() + + assert tokens.shape == labels.shape, f"tokens: {tokens.shape} != labels: {labels.shape}" + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push("get_ltor_masks_and_position_ids") + loss_mask, position_ids = get_ltor_masks_and_position_ids(tokens, labels, tokenizer.pad) + torch.cuda.nvtx.range_pop() + + # If context parallel is enabled, must shard inputs to CP ranks. 
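The sharding branch below first expands each image token into img_seq_len embeddings when computing the effective sequence length, then pads so the sequence divides evenly across context-parallel ranks (the exact padding rule lives in context_parallel.get_padding and also accounts for tensor parallelism with sequence parallel). A small worked example with hypothetical sizes:

# Hypothetical sizes; mirrors the seq_len arithmetic in the CP branch below.
text_length = 512          # decoder tokens, including one image placeholder
num_image_tokens = 1
img_seq_len = 256          # embeddings each image token expands into
num_image_embeddings = num_image_tokens * img_seq_len - num_image_tokens  # 255
seq_len = text_length + num_image_embeddings                              # 767
context_parallel_size = 2
# Assumed for illustration only: pad up to a multiple of 2 * CP size, the
# divisibility TransformerEngine's context-parallel attention typically expects.
multiple = 2 * context_parallel_size
mp_padding_needed = (-seq_len) % multiple
print(seq_len + mp_padding_needed)  # 768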
+ if args.context_parallel_size > 1 or args.sequence_parallel: + assert tokens.shape[0], "micro-batch-size > 1 not supported yet with CP" + + num_image_tokens = torch.sum(tokens == image_token_index).item() + num_image_embeddings = num_image_tokens * img_seq_len - num_image_tokens + seq_len = text_length + num_image_embeddings + + # CP expects sequence length is divisible by CP size so apply padding. + mp_padding_needed = context_parallel.get_padding( + seq_len, args.context_parallel_size, + args.tensor_model_parallel_size, args.sequence_parallel, + ) + tokens, position_ids, labels, loss_mask = [torch.nn.functional.pad(item, (0, mp_padding_needed)) for item in (tokens, position_ids, labels, loss_mask)] + + # Get PackedSeqParams that indicate the amount of padding for TransformerEngine. + packed_seq_params = context_parallel.get_packed_seq_params(tokens, num_image_embeddings, mp_padding_needed, args.context_parallel_size, True) + + return ( + tokens, + labels, + loss_mask, + attention_mask, + position_ids, + imgs, + num_tiles, + packed_seq_params, + ) + + +def get_ltor_masks_and_position_ids(input_ids, target, pad_token): + """Build masks and position id for left to right model.""" + seq_length = input_ids.shape[1] + + # Position ids. + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + # Loss mask. + loss_mask = torch.ones(target.size(), dtype=torch.float, device=input_ids.device) + loss_mask[target == pad_token] = 0.0 # mask paddings + loss_mask[target == IGNORE_INDEX] = 0.0 # mask prompts + + return loss_mask, position_ids + + +def get_mask_start_and_end_idx(arr): + """ + Returns a list of tuples holding the start and end index in arr of the non-zeros contiguuous + sub arrays. + + For instance, if arr = [0, 1, 0, 0, 1, 1] + get_mask_start_and_end_idx(arr) = [(1, 1), (4, 5)] + such that arr[1:1+1] = [1] and arr[4:5+1] = [1, 1] + """ + mask = (arr != 0) + + mask_int = mask.int() + + diff = mask_int[1:] - mask_int[:-1] + start_indices = (diff == 1).nonzero(as_tuple=False).flatten() + 1 + end_indices = (diff == -1).nonzero(as_tuple=False).flatten() + if len(mask)==0: return [] + if mask[0]: + start_indices = torch.cat((torch.tensor([0], device=arr.device), start_indices)) + if mask[-1]: + end_indices = torch.cat((end_indices, torch.tensor([len(arr) - 1], device=arr.device))) + sequences = list(zip(start_indices.tolist(), end_indices.tolist())) + return sequences + + +def scaled_loss_func(loss_mask, output_tensor): + """ + Scaled loss function + + Scale the loss for each conversation turn using the formula: + + 1 / sum_j[ sqrt(length(loss_turn_j)) ] * sum_i[ sum(loss_turn_i) / sqrt(length(loss_turn_i)) ] + + Where we use the loss mask to infer the start / end of the conversation turns. 
+ """ + losses = output_tensor.float() + + loss_list = [] + num_valid_labels_list = [] + for idx in range(losses.shape[0]): + loss_this_sample = losses[idx] + turn_start_end_list = get_mask_start_and_end_idx(loss_mask[idx]) + for turn_start, turn_end in turn_start_end_list: + # compute loss for each turn + loss_this_turn = loss_this_sample[turn_start:turn_end+1].sum() + assert (1 - loss_mask)[idx][turn_start:turn_end+1].sum() < 1.0 + num_valid_labels_this_turn = turn_end - turn_start + 1 + loss_this_turn = loss_this_turn / num_valid_labels_this_turn + loss_list.append(loss_this_turn) + # append num of valid labels for each turn + num_valid_labels_list.append(num_valid_labels_this_turn) + base_num = sum([math.sqrt(each) for each in num_valid_labels_list]) + for idx in range(len(loss_list)): + # normalize loss for each turn + loss_list[idx] = loss_list[idx] * math.sqrt(num_valid_labels_list[idx]) / base_num + + total_loss = torch.stack(loss_list).sum() + total_tokens = torch.ones_like(total_loss) + + loss = torch.cat([total_loss.view(1), total_tokens.view(1)]) + + reporting_loss = loss.clone().detach() + torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) + + local_num_tokens = loss[1].clone().detach().to(torch.int) + + return ( + total_loss, + local_num_tokens, + {'lm loss': (reporting_loss[0], reporting_loss[1])}, + ) + + +def loss_func(loss_mask, output_tensor): + args = get_args() + + losses = output_tensor.float() + + loss_mask = loss_mask.contiguous().view(-1).float() + + total_tokens = loss_mask.sum() + total_loss = torch.sum(losses.view(-1) * loss_mask) + loss = torch.cat([total_loss.view(1), total_tokens.view(1)]) + + if args.context_parallel_size > 1: + torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group()) + + reporting_loss = loss.clone().detach() + torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) + + local_num_tokens = loss[1].clone().detach().to(torch.int) + + # We multiply by context parallel size because later there will be a divide by CP(+DP) size. + return ( + loss[0] * args.context_parallel_size, + local_num_tokens, + {'lm loss': (reporting_loss[0], reporting_loss[1])} + ) + + +def forward_step(data_iterator, model: LLaVAModel): + """Forward training step. + + Args: + data_iterator (torch.utils.data.dataloader): Input data iterator + model: Multimodal model + + Returns: + output_tensor (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape [b, s, vocab_size]. + loss_func (callable): Loss function with a loss mask specified. + """ + timers = get_timers() + + # Get the batch. + timers('batch-generator', log_level=2).start() + ( + tokens, + labels, + loss_mask, + attention_mask, + position_ids, + images, + num_image_tiles, + packed_seq_params, + ) = get_batch(data_iterator, model.module.module.image_token_index, model.module.module.img_seq_len) + timers('batch-generator').stop() + + output_tensor, loss_mask = model( + images, + tokens, + position_ids, + attention_mask, + labels, + loss_mask, + num_image_tiles=num_image_tiles, + packed_seq_params=packed_seq_params, + ) + args = get_args() + if args.use_loss_scaling: + loss_function = partial(scaled_loss_func, loss_mask) + else: + loss_function = partial(loss_func, loss_mask) + + return output_tensor, loss_function + + +def llava_embedding_ranks(pp_ranks): + """LLava's embedding ranks consist of the decoder's first and last ranks (ie, the ViT has no embeddings). 
+ Args: + pp_ranks: A list of global ranks that constitute a pipeline group. + """ + args = get_args() + + # encoder size is also the index to the first rank of the decoder. + epp = args.encoder_pipeline_model_parallel_size + + last_rank = pp_ranks[-1] + if len(pp_ranks) == 1 or pp_ranks[epp] == last_rank: + return [last_rank] + else: + return [pp_ranks[epp], last_rank] + + +def llava_position_embedding_ranks(pp_ranks): + """LLava's embedding ranks consist of the singular rank of the model or the decoder's first rank. + Args: + pp_ranks: A list of global ranks that constitute a pipeline group. + """ + args = get_args() + + # encoder size is also the index to the first rank of the decoder. + epp = args.encoder_pipeline_model_parallel_size + + last_rank = pp_ranks[-1] + if len(pp_ranks) == 1: + return [last_rank] + else: + return [pp_ranks[epp]] + + +def run_online_eval(model): + """Run an evaluation benchmark during training.""" + args = get_args() + + # Online evaluation config is not defined. Do nothing. + if not args.online_evaluation_config: + return [] + + from config import EvaluationConfig + from run_text_generation import generate_and_write_samples + + with open(args.online_evaluation_config, "r") as f: + config_dict = yaml.safe_load(f) + + config = EvaluationConfig(**config_dict) + + # The inference code assumes the first rank is the leader. + # Tensorboard writer is on the last rank. + # We must write to a storage space that all ranks see. + output_dir = os.path.join(args.save, "online_eval") + os.makedirs(output_dir, exist_ok=True) + config.output_path = os.path.join(output_dir, args.language_model_type) + + # The actual generation. + generate_and_write_samples(model[0].module, config, print_output=False) + + # Make sure the first rank is done writing so that the last rank can run eval. + torch.distributed.barrier() + + if not is_last_rank(): + return [] + + # Run evaluation. 
+ if config.task == "TextVQA": + from evaluate_textvqa import textvqa_eval + + avg_acc = textvqa_eval(config.output_path) + + return [{"TextVQA accuracy": avg_acc}] + else: + raise NotImplementedError(f"online evaluation of {config.task} not implemented yet") + + +def write_online_eval_to_tensorboard(data, iteration, writer): + """Write online evaluation data to Tensorboard.""" + if not writer: + return + + for item in data: + for k, v in item.items(): + writer.add_scalar(k, v, iteration) + + +if __name__ == "__main__": + + train_valid_test_dataloaders_provider.is_distributed = True + + pretrain( + train_valid_test_dataloaders_provider, + model_provider, + ModelType.encoder_and_decoder, + forward_step, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + extra_args_provider=add_multimodal_extra_args, + process_non_loss_data_func=write_online_eval_to_tensorboard, + get_embedding_ranks=llava_embedding_ranks, + get_position_embedding_ranks=llava_position_embedding_ranks, + non_loss_data_func=run_online_eval, + ) diff --git a/examples/pretrain_bert.sh b/examples/pretrain_bert.sh deleted file mode 100755 index 3877b1a5f4..0000000000 --- a/examples/pretrain_bert.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/bin/bash - -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -CHECKPOINT_PATH= -VOCAB_FILE=/bert-vocab.txt -DATA_PATH=_text_sentence - -BERT_ARGS=" - --num-layers 24 \ - --hidden-size 1024 \ - --num-attention-heads 16 \ - --seq-length 512 \ - --max-position-embeddings 512 \ - --micro-batch-size 4 \ - --global-batch-size 8 \ - --lr 0.0001 \ - --train-iters 2000000 \ - --lr-decay-iters 990000 \ - --lr-decay-style linear \ - --min-lr 0.00001 \ - --weight-decay 1e-2 \ - --lr-warmup-fraction .01 \ - --clip-grad 1.0 \ - --fp16 -" - -DATA_ARGS=" - --data-path $DATA_PATH \ - --vocab-file $VOCAB_FILE \ - --split 949,50,1 -" - -OUTPUT_ARGS=" - --log-interval 100 \ - --save-interval 10000 \ - --eval-interval 1000 \ - --eval-iters 10 -" - -torchrun pretrain_bert.py \ - $BERT_ARGS \ - $DATA_ARGS \ - $OUTPUT_ARGS \ - --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH diff --git a/examples/pretrain_bert_distributed.sh b/examples/pretrain_bert_distributed.sh deleted file mode 100755 index 2e0209ae6b..0000000000 --- a/examples/pretrain_bert_distributed.sh +++ /dev/null @@ -1,63 +0,0 @@ -#!/bin/bash - -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -GPUS_PER_NODE=8 -# Change for multinode config -MASTER_ADDR=localhost -MASTER_PORT=6000 -NNODES=1 -NODE_RANK=0 -WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) - -CHECKPOINT_PATH= -VOCAB_FILE=/bert-vocab.txt -DATA_PATH=_text_sentence - -DISTRIBUTED_ARGS=" - --nproc_per_node $GPUS_PER_NODE \ - --nnodes $NNODES \ - --node_rank $NODE_RANK \ - --master_addr $MASTER_ADDR \ - --master_port $MASTER_PORT -" - -BERT_ARGS=" - --num-layers 24 \ - --hidden-size 1024 \ - --num-attention-heads 16 \ - --seq-length 512 \ - --max-position-embeddings 512 \ - --micro-batch-size 4 \ - --global-batch-size 32 \ - --lr 0.0001 \ - --train-iters 1000000 \ - --lr-decay-iters 990000 \ - --lr-decay-style linear \ - --min-lr 1.0e-5 \ - --weight-decay 1e-2 \ - --lr-warmup-fraction .01 \ - --clip-grad 1.0 \ - --fp16 -" - -DATA_ARGS=" - --data-path $DATA_PATH \ - --vocab-file $VOCAB_FILE \ - --split 949,50,1 -" - -OUTPUT_ARGS=" - --log-interval 100 \ - --save-interval 10000 \ - --eval-interval 1000 \ - --eval-iters 10 -" - -torchrun $DISTRIBUTED_ARGS pretrain_bert.py \ - $BERT_ARGS \ - $DATA_ARGS \ - $OUTPUT_ARGS \ - --distributed-backend nccl \ - --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH diff --git 
a/examples/pretrain_bert_distributed_with_mp.sh b/examples/pretrain_bert_distributed_with_mp.sh deleted file mode 100755 index 93a22c95a9..0000000000 --- a/examples/pretrain_bert_distributed_with_mp.sh +++ /dev/null @@ -1,65 +0,0 @@ -#!/bin/bash - -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -GPUS_PER_NODE=8 -# Change for multinode config -MASTER_ADDR=localhost -MASTER_PORT=6000 -NNODES=1 -NODE_RANK=0 -WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) - -CHECKPOINT_PATH= -VOCAB_FILE=/bert-vocab.txt -DATA_PATH=_text_sentence - -DISTRIBUTED_ARGS=" - --nproc_per_node $GPUS_PER_NODE \ - --nnodes $NNODES \ - --node_rank $NODE_RANK \ - --master_addr $MASTER_ADDR \ - --master_port $MASTER_PORT -" - -BERT_ARGS=" - --tensor-model-parallel-size 2 \ - --pipeline-model-parallel-size 2 \ - --num-layers 24 \ - --hidden-size 1024 \ - --num-attention-heads 16 \ - --seq-length 512 \ - --max-position-embeddings 512 \ - --micro-batch-size 2 \ - --global-batch-size 16 \ - --lr 0.0001 \ - --train-iters 1000000 \ - --lr-decay-iters 990000 \ - --lr-decay-style linear \ - --min-lr 1.0e-5 \ - --weight-decay 1e-2 \ - --lr-warmup-fraction .01 \ - --clip-grad 1.0 \ - --fp16 -" - -DATA_ARGS=" - --data-path $DATA_PATH \ - --vocab-file $VOCAB_FILE \ - --split 949,50,1 -" - -OUTPUT_ARGS=" - --log-interval 100 \ - --save-interval 10000 \ - --eval-interval 1000 \ - --eval-iters 10 -" - -torchrun $DISTRIBUTED_ARGS pretrain_bert.py \ - $BERT_ARGS \ - $DATA_ARGS \ - $OUTPUT_ARGS \ - --distributed-backend nccl \ - --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH diff --git a/examples/pretrain_gpt.sh b/examples/pretrain_gpt.sh deleted file mode 100755 index 1d4b20f004..0000000000 --- a/examples/pretrain_gpt.sh +++ /dev/null @@ -1,50 +0,0 @@ -#!/bin/bash - -# Runs the "345M" parameter model - -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -CHECKPOINT_PATH= -VOCAB_FILE=/gpt2-vocab.json -MERGE_FILE=/gpt2-merges.txt -DATA_PATH=_text_document - -GPT_ARGS=" - --num-layers 24 \ - --hidden-size 1024 \ - --num-attention-heads 16 \ - --seq-length 1024 \ - --max-position-embeddings 1024 \ - --micro-batch-size 4 \ - --global-batch-size 8 \ - --lr 0.00015 \ - --train-iters 500000 \ - --lr-decay-iters 320000 \ - --lr-decay-style cosine \ - --min-lr 1.0e-5 \ - --weight-decay 1e-2 \ - --lr-warmup-fraction .01 \ - --clip-grad 1.0 \ - --fp16 -" - -DATA_ARGS=" - --data-path $DATA_PATH \ - --vocab-file $VOCAB_FILE \ - --merge-file $MERGE_FILE \ - --split 949,50,1 -" - -OUTPUT_ARGS=" - --log-interval 100 \ - --save-interval 10000 \ - --eval-interval 1000 \ - --eval-iters 10 -" - -torchrun pretrain_gpt.py \ - $GPT_ARGS \ - $DATA_ARGS \ - $OUTPUT_ARGS \ - --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH diff --git a/examples/pretrain_gpt3_175B.sh b/examples/pretrain_gpt3_175B.sh deleted file mode 100755 index c26b8ee6c8..0000000000 --- a/examples/pretrain_gpt3_175B.sh +++ /dev/null @@ -1,64 +0,0 @@ -#!/bin/bash - - -#SBATCH --nodes=128 --exclusive --ntasks-per-node=8 --job-name=megatron_gpt3_175b - - -DIR=`pwd` -DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` -mkdir -p $DIR/logs - - -DATASET_1="" -DATASET_2="" -DATASET_3="" -DATASET="0.2 ${DATASET_1} 0.3 ${DATASET_2} 0.5 ${DATASET_3}" - - -options=" \ - --tensor-model-parallel-size 8 \ - --pipeline-model-parallel-size 16 \ - --num-layers 96 \ - --hidden-size 12288 \ - --num-attention-heads 96 \ - --seq-length 2048 \ - --max-position-embeddings 2048 \ - --micro-batch-size 1 \ - --global-batch-size 1536 \ - --rampup-batch-size 16 16 5859375 \ - --train-samples 146484375 \ - --lr-decay-samples 126953125 \ - 
--lr-warmup-samples 183105 \ - --lr 6.0e-5 \ - --min-lr 6.0e-6 \ - --lr-decay-style cosine \ - --log-interval 10 \ - --eval-iters 40 \ - --eval-interval 1000 \ - --data-path ${DATASET} \ - --vocab-file \ - --merge-file \ - --save-interval 1000 \ - --save \ - --load \ - --split 98,2,0 \ - --clip-grad 1.0 \ - --weight-decay 0.1 \ - --adam-beta1 0.9 \ - --adam-beta2 0.95 \ - --init-method-std 0.006 \ - --tensorboard-dir \ - --fp16 " - - -run_cmd="python -u ${DIR}/pretrain_gpt.py $@ ${options}" - - -srun -l \ - --container-image "nvcr.io/nvidia/pytorch:20.12-py3" \ - --container-mounts "" \ - --output=$DIR/logs/%x_%j_$DATETIME.log sh -c "${run_cmd}" - - -set +x - diff --git a/examples/pretrain_gpt_distributed.sh b/examples/pretrain_gpt_distributed.sh deleted file mode 100755 index effce206d3..0000000000 --- a/examples/pretrain_gpt_distributed.sh +++ /dev/null @@ -1,67 +0,0 @@ -#!/bin/bash - -# Runs the "345M" parameter model - -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -GPUS_PER_NODE=8 -# Change for multinode config -MASTER_ADDR=localhost -MASTER_PORT=6000 -NNODES=1 -NODE_RANK=0 -WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) - -CHECKPOINT_PATH= -VOCAB_FILE=/gpt2-vocab.json -MERGE_FILE=/gpt2-merges.txt -DATA_PATH=_text_document - -DISTRIBUTED_ARGS=" - --nproc_per_node $GPUS_PER_NODE \ - --nnodes $NNODES \ - --node_rank $NODE_RANK \ - --master_addr $MASTER_ADDR \ - --master_port $MASTER_PORT -" - -GPT_ARGS=" - --num-layers 24 \ - --hidden-size 1024 \ - --num-attention-heads 16 \ - --seq-length 1024 \ - --max-position-embeddings 1024 \ - --micro-batch-size 8 \ - --global-batch-size 64 \ - --lr 0.00015 \ - --train-iters 500000 \ - --lr-decay-iters 320000 \ - --lr-decay-style cosine \ - --min-lr 1.0e-5 \ - --weight-decay 1e-2 \ - --lr-warmup-fraction .01 \ - --clip-grad 1.0 \ - --fp16 -" - -DATA_ARGS=" - --data-path $DATA_PATH \ - --vocab-file $VOCAB_FILE \ - --merge-file $MERGE_FILE \ - --split 949,50,1 -" - -OUTPUT_ARGS=" - --log-interval 100 \ - --save-interval 10000 \ - --eval-interval 1000 \ - --eval-iters 10 -" - -torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \ - $GPT_ARGS \ - $DATA_ARGS \ - $OUTPUT_ARGS \ - --distributed-backend nccl \ - --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH diff --git a/examples/pretrain_gpt_distributed_with_mp.sh b/examples/pretrain_gpt_distributed_with_mp.sh deleted file mode 100755 index 470a2560d3..0000000000 --- a/examples/pretrain_gpt_distributed_with_mp.sh +++ /dev/null @@ -1,71 +0,0 @@ -#!/bin/bash - -# Runs the "345M" parameter model - -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -GPUS_PER_NODE=8 -# Change for multinode config -MASTER_ADDR=localhost -MASTER_PORT=6000 -NNODES=1 -NODE_RANK=0 -WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) - -CHECKPOINT_PATH= -VOCAB_FILE=/gpt2-vocab.json -MERGE_FILE=/gpt2-merges.txt -DATA_PATH=_text_document - -DISTRIBUTED_ARGS=" - --nproc_per_node $GPUS_PER_NODE \ - --nnodes $NNODES \ - --node_rank $NODE_RANK \ - --master_addr $MASTER_ADDR \ - --master_port $MASTER_PORT -" - -GPT_ARGS=" - --tensor-model-parallel-size 2 \ - --pipeline-model-parallel-size 2 \ - --sequence-parallel \ - --num-layers 24 \ - --hidden-size 1024 \ - --num-attention-heads 16 \ - --seq-length 1024 \ - --max-position-embeddings 1024 \ - --micro-batch-size 4 \ - --global-batch-size 16 \ - --lr 0.00015 \ - --train-iters 500000 \ - --lr-decay-iters 320000 \ - --lr-decay-style cosine \ - --min-lr 1.0e-5 \ - --weight-decay 1e-2 \ - --lr-warmup-fraction .01 \ - --clip-grad 1.0 \ - --fp16 -" - -DATA_ARGS=" - --data-path $DATA_PATH \ - --vocab-file $VOCAB_FILE \ - --merge-file 
$MERGE_FILE \ - --split 949,50,1 -" - -OUTPUT_ARGS=" - --log-interval 100 \ - --save-interval 10000 \ - --eval-interval 1000 \ - --eval-iters 10 -" - -torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \ - $GPT_ARGS \ - $DATA_ARGS \ - $OUTPUT_ARGS \ - --distributed-backend nccl \ - --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH - diff --git a/examples/pretrain_ict.sh b/examples/pretrain_ict.sh deleted file mode 100755 index 8cba0f08ba..0000000000 --- a/examples/pretrain_ict.sh +++ /dev/null @@ -1,44 +0,0 @@ -#! /bin/bash - -# Runs the "217M" parameter biencoder model for ICT retriever - -RANK=0 -WORLD_SIZE=1 - -PRETRAINED_BERT_PATH= -TEXT_DATA_PATH= -TITLE_DATA_PATH= -CHECKPOINT_PATH= - - -python pretrain_ict.py \ - --num-layers 12 \ - --hidden-size 768 \ - --num-attention-heads 12 \ - --tensor-model-parallel-size 1 \ - --micro-batch-size 32 \ - --seq-length 256 \ - --max-position-embeddings 512 \ - --train-iters 100000 \ - --vocab-file bert-vocab.txt \ - --tokenizer-type BertWordPieceLowerCase \ - --DDP-impl torch \ - --bert-load ${PRETRAINED_BERT_PATH} \ - --log-interval 100 \ - --eval-interval 1000 \ - --eval-iters 10 \ - --retriever-report-topk-accuracies 1 5 10 20 100 \ - --retriever-score-scaling \ - --load $CHECKPOINT_PATH \ - --save $CHECKPOINT_PATH \ - --data-path ${TEXT_DATA_PATH} \ - --titles-data-path ${TITLE_DATA_PATH} \ - --lr 0.0001 \ - --lr-decay-style linear \ - --weight-decay 1e-2 \ - --clip-grad 1.0 \ - --lr-warmup-fraction 0.01 \ - --save-interval 4000 \ - --exit-interval 8000 \ - --query-in-block-prob 0.1 \ - --fp16 diff --git a/examples/pretrain_t5.sh b/examples/pretrain_t5.sh deleted file mode 100644 index c44cc5763c..0000000000 --- a/examples/pretrain_t5.sh +++ /dev/null @@ -1,50 +0,0 @@ -#!/bin/bash - -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -CHECKPOINT_PATH= -VOCAB_FILE=/t5-vocab.txt -DATA_PATH=_text_sentence - -T5_ARGS=" - --num-layers 12 \ - --hidden-size 768 \ - --num-attention-heads 12 \ - --kv-channels 64 \ - --ffn-hidden-size 3072 \ - --encoder-seq-length 512 \ - --decoder-seq-length 128 \ - --max-position-embeddings 512 \ - --micro-batch-size 16 \ - --global-batch-size 16 \ - --lr 0.0001 \ - --train-iters 1000000 \ - --lr-decay-iters 1000000 \ - --lr-decay-style linear \ - --min-lr 0.00001 \ - --weight-decay 1e-2 \ - --lr-warmup-fraction .01 \ - --clip-grad 1.0 \ - --fp16 \ - --vocab-extra-ids 100 -" - -DATA_ARGS=" - --data-path $DATA_PATH \ - --vocab-file $VOCAB_FILE \ - --split 949,50,1 -" - -OUTPUT_ARGS=" - --log-interval 100 \ - --save-interval 10000 \ - --eval-interval 1000 \ - --eval-iters 10 -" - -torchrun pretrain_t5.py \ - $T5_ARGS \ - $DATA_ARGS \ - $OUTPUT_ARGS \ - --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH diff --git a/examples/pretrain_t5_distributed_with_mp.sh b/examples/pretrain_t5_distributed_with_mp.sh deleted file mode 100644 index 9802866263..0000000000 --- a/examples/pretrain_t5_distributed_with_mp.sh +++ /dev/null @@ -1,68 +0,0 @@ -#!/bin/bash - -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -GPUS_PER_NODE=8 -# Change for multinode config -MASTER_ADDR=localhost -MASTER_PORT=6000 -NNODES=1 -NODE_RANK=0 -WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) - -CHECKPOINT_PATH= -VOCAB_FILE=/t5-vocab.txt -DATA_PATH=_text_sentence - -DISTRIBUTED_ARGS=" - --nproc_per_node $GPUS_PER_NODE \ - --nnodes $NNODES \ - --node_rank $NODE_RANK \ - --master_addr $MASTER_ADDR \ - --master_port $MASTER_PORT -" - -T5_ARGS=" - --tensor-model-parallel-size 2 \ - --num-layers 12 \ - --hidden-size 768 \ - --num-attention-heads 12 \ - --kv-channels 64 \ - --ffn-hidden-size 3072 
\ - --encoder-seq-length 512 \ - --decoder-seq-length 128 \ - --max-position-embeddings 512 \ - --micro-batch-size 16 \ - --global-batch-size 128 \ - --lr 0.0001 \ - --train-iters 1000000 \ - --lr-decay-iters 1000000 \ - --lr-decay-style linear \ - --min-lr 0.00001 \ - --weight-decay 1e-2 \ - --lr-warmup-fraction .01 \ - --clip-grad 1.0 \ - --fp16 \ - --vocab-extra-ids 100 -" - -DATA_ARGS=" - --data-path $DATA_PATH \ - --vocab-file $VOCAB_FILE \ - --split 949,50,1 -" - -OUTPUT_ARGS=" - --log-interval 100 \ - --save-interval 10000 \ - --eval-interval 1000 \ - --eval-iters 10 -" - -torchrun $DISTRIBUTED_ARGS pretrain_t5.py \ - $T5_ARGS \ - $DATA_ARGS \ - $OUTPUT_ARGS \ - --distributed-backend nccl \ - --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH diff --git a/examples/pretrain_vision_classify.sh b/examples/pretrain_vision_classify.sh deleted file mode 100755 index 5fcdd6e6ef..0000000000 --- a/examples/pretrain_vision_classify.sh +++ /dev/null @@ -1,64 +0,0 @@ -#! /bin/bash - -# Pre-trains ViT based image classificaation model - -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export NCCL_IB_SL=1 - -# Training and validation paths should each point to a folder where each -# sub-folder contains a collection of images in jpg or png format -# e.g. If using imagenet, one train image might be, train_data/n01688243/n01688243_11301.JPEG -DATA_PATH_TRAIN= -DATA_PATH_VAL= - -CHECKPOINT_PATH= - -CLASSIFIER_ARGS=" - --tensor-model-parallel-size 1 \ - --num-layers 12 \ - --hidden-size 768 \ - --num-attention-heads 12 \ - --patch-dim 4 \ - --seq-length 3136 \ - --max-position-embeddings 3136 \ - --img-h 224 \ - --img-w 224 \ - --mask-factor 1.0 \ - --fp16 \ - --train-iters 750000 \ - --lr-decay-style cosine \ - --micro-batch-size 4 \ - --global-batch-size 1024 \ - --lr 0.0005 \ - --min-lr 0.00001 \ - --attention-dropout 0.0 \ - --weight-decay 0.05 \ - --lr-warmup-iters 12500 \ - --clip-grad 1.0 \ - --no-gradient-accumulation-fusion \ - --num-workers 4 \ - --DDP-impl torch " - -DATA_ARGS=" - --tokenizer-type NullTokenizer \ - --vocab-size 0 \ - --data-path $DATA_PATH_TRAIN $DATA_PATH_VAL \ - --no-data-sharding \ - --split 949,50,1 \ -" - -OUTPUT_ARG=" - --log-interval 32 \ - --save-interval 10000 \ - --eval-interval 2500 \ - --eval-iters 100 \ - --tensorboard-dir ${CHECKPOINT_PATH} \ -" - -torchrun pretrain_vision_classification.py \ - $CLASSIFIER_ARGS \ - $DATA_ARGS \ - $OUTPUT_ARGS \ - --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH - diff --git a/examples/pretrain_vision_dino.sh b/examples/pretrain_vision_dino.sh deleted file mode 100755 index b047e4e340..0000000000 --- a/examples/pretrain_vision_dino.sh +++ /dev/null @@ -1,67 +0,0 @@ -#! /bin/bash - -# Pre-trains Dino V1 model -# For model details: https://arxiv.org/abs/2104.14294 -# For original author implementation: https://github.com/facebookresearch/dino/tree/main - -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export NCCL_IB_SL=1 - -# Training and validation paths should each point to a folder where each -# sub-folder contains a collection of images in jpg or png format -# e.g. 
If using imagenet, one train image might be, train_data/n01688243/n01688243_11301.JPEG -DATA_PATH_TRAIN= -DATA_PATH_VAL= - -CHECKPOINT_PATH= - -DINO_ARGS=" - --vision-pretraining-type dino \ - --tensor-model-parallel-size 1 \ - --num-layers 12 \ - --hidden-size 768 \ - --num-attention-heads 12 \ - --patch-dim 4 \ - --seq-length 3136 \ - --max-position-embeddings 3136 \ - --img-h 224 \ - --img-w 224 \ - --mask-factor 1.0 \ - --fp16 \ - --train-iters 750000 \ - --lr-decay-style cosine \ - --micro-batch-size 4 \ - --global-batch-size 1024 \ - --lr 0.0005 \ - --min-lr 0.00001 \ - --attention-dropout 0.0 \ - --weight-decay 0.05 \ - --lr-warmup-iters 12500 \ - --clip-grad 1.0 \ - --no-gradient-accumulation-fusion \ - --num-workers 4 \ - --DDP-impl torch " - -DATA_ARGS=" - --tokenizer-type NullTokenizer \ - --vocab-size 0 \ - --data-path $DATA_PATH_TRAIN $DATA_PATH_VAL \ - --no-data-sharding \ - --split 949,50,1 \ -" - -OUTPUT_ARG=" - --log-interval 32 \ - --save-interval 10000 \ - --eval-interval 2500 \ - --eval-iters 100 \ - --tensorboard-dir ${CHECKPOINT_PATH} \ -" - -torchrun pretrain_vision_dino.py \ - $DINO_ARGS \ - $DATA_ARGS \ - $OUTPUT_ARGS \ - --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH - diff --git a/examples/pretrain_vision_inpaint.sh b/examples/pretrain_vision_inpaint.sh deleted file mode 100755 index 01c7e71a9e..0000000000 --- a/examples/pretrain_vision_inpaint.sh +++ /dev/null @@ -1,65 +0,0 @@ -#! /bin/bash - -# Pre-trains ViT based image inpainting model - -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export NCCL_IB_SL=1 - -# Training and validation paths should each point to a folder where each -# sub-folder contains a collection of images in jpg or png format -# e.g. If using imagenet, one train image might be, train_data/n01688243/n01688243_11301.JPEG -DATA_PATH_TRAIN= -DATA_PATH_VAL= - -CHECKPOINT_PATH= - -INPAINT_ARGS=" - --vision-pretraining-type inpaint \ - --tensor-model-parallel-size 1 \ - --num-layers 12 \ - --hidden-size 768 \ - --num-attention-heads 12 \ - --patch-dim 4 \ - --seq-length 3136 \ - --max-position-embeddings 3136 \ - --img-h 224 \ - --img-w 224 \ - --mask-factor 1.0 \ - --fp16 \ - --train-iters 750000 \ - --lr-decay-style cosine \ - --micro-batch-size 4 \ - --global-batch-size 1024 \ - --lr 0.0005 \ - --min-lr 0.00001 \ - --attention-dropout 0.0 \ - --weight-decay 0.05 \ - --lr-warmup-iters 12500 \ - --clip-grad 1.0 \ - --no-gradient-accumulation-fusion \ - --num-workers 4 \ - --DDP-impl torch " - -DATA_ARGS=" - --tokenizer-type NullTokenizer \ - --vocab-size 0 \ - --data-path $DATA_PATH_TRAIN $DATA_PATH_VAL \ - --no-data-sharding \ - --split 949,50,1 \ -" - -OUTPUT_ARG=" - --log-interval 32 \ - --save-interval 10000 \ - --eval-interval 2500 \ - --eval-iters 100 \ - --tensorboard-dir ${CHECKPOINT_PATH} \ -" - -torchrun pretrain_vision_inpaint.py \ - $INPAINT_ARGS \ - $DATA_ARGS \ - $OUTPUT_ARGS \ - --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH - diff --git a/examples/retro/README.md b/examples/retro/README.md new file mode 100644 index 0000000000..f78bcdeb56 --- /dev/null +++ b/examples/retro/README.md @@ -0,0 +1,74 @@ +# RETRO MODEL + +## Table of contents +- [1. Training Setup](#1-training-setup) +- [2. Data Preprocessing](#2-data-preprocessing) +- [3. Configurations](#3-configurations) + +## 1. 
Training setup + + +To run the model using a Docker container, run it as follows: +``` +PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:23.09-py3 +CHECKPOINT_PATH="" # +TENSORBOARD_LOGS_PATH=""# + +docker run \ + --gpus=all \ + --ipc=host \ + --workdir /workspace/megatron-lm \ + -v /path/to/data:/path/to/data \ + -v /path/to/megatron-lm:/workspace/megatron-lm \ + megatron-lm nvcr.io/nvidia/pytorch:23.09-py3 \ + bash examples/retro/train_retro_2b_distributed.sh $CHECKPOINT_PATH $TENSORBOARD_LOGS_PATH + +``` +NOTE: Depending on the environment you are running in, the above command might look slightly different. + +NOTE: Because Retro preprocesses and caches elements of the pretraining dataset before training begins, some arguments are auto-loaded from the Retro preprocessing configuration. These loaded arguments include: + +- `--data-path` +- `--data-cache-path` +- `--eval-interval` +- `--eval-iters` +- `--global-batch-size` +- `--tokenizer-type` +- `--tokenizer-model` +- `--vocab-file` +- `--merge-file` +- `--seed` +- `--seq-length` +- `--train-samples` + + +## 2. Data Preprocessing + + +Retro preprocesses and caches data prior to pretraining, which greatly speeds up pretraining. During data preprocessing, the retrieval database is built, and neighbor IDs are queried for each sample within the pretraining dataset. Please see `preprocess_data.sh` for an example script to preprocess data for Retro; a minimal invocation sketch is shown below. The reference documentation for data preprocessing can be found [here](tools/retro/README.md). + + +## 3. Configurations + +The example in this folder shows you how to run a 2B model. Below are a few other example configurations. + +### 857M +``` + --num-layers 24 \ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --seq-length 2048 \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + +``` + +### 4B +``` + --num-layers 48 \ + --hidden-size 2560 \ + --num-attention-heads 32 \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + +``` diff --git a/tools/retro/examples/preprocess_data.sh b/examples/retro/preprocess_data.sh similarity index 71% rename from tools/retro/examples/preprocess_data.sh rename to examples/retro/preprocess_data.sh index e60a718615..5d2e66ba0e 100644 --- a/tools/retro/examples/preprocess_data.sh +++ b/examples/retro/preprocess_data.sh @@ -7,23 +7,31 @@ unset NCCL_DEBUG ######## Megatron, Retro dirs. ######## REPO_DIR="" -RETRO_WORKDIR="" +RETRO_PROJECT_DIR="" ######## Task (e.g., db, index, query). ######## -RETRO_TASKS="db-build" -# RETRO_TASKS="index-train" -# RETRO_TASKS="index-add" -# RETRO_TASKS="query-pretraining-neighbors" +# This script takes a single argument, which specifies the retro task to be +# performed. The available tasks are: db-build, index-train, index-add, and +# query-neighbors. -######## Data. ######## +# ~~ Examples ~~ +# RETRO_TASKS="db-build" # Build the retrieval database +# RETRO_TASKS="index-train" # Train the index +# RETRO_TASKS="index-add" # Add data to the index +# RETRO_TASKS="query-neighbors" # Perform query pretraining for neighbors + +# You can also provide the task as a command-line argument when executing the +# script. Example: ./preprocess_data.sh index-add +RETRO_TASKS=$1 +######## Data. ######## DATA_BLEND="" ######## Index.
######## RETRO_INDEX_STR="OPQ32_64,IVF65536_HNSW8,PQ32" -RETRO_INDEX_NTRAIN=1000000 +RETRO_INDEX_NTRAIN=66625331 RETRO_INDEX_TRAIN_LOAD_FRACTION=0.97 RETRO_INDEX_ADD_LOAD_FRACTION=0.95 @@ -32,19 +40,19 @@ RETRO_INDEX_ADD_LOAD_FRACTION=0.95 RETRO_GPT_SEED=1234 RETRO_GPT_SPLIT="98,2,0" RETRO_GPT_DATA_PATH=${DATA_BLEND} -RETRO_GPT_DATALOADER_TYPE=single +RETRO_GPT_TRAIN_SAMPLES=200000 RETRO_GPT_EVAL_INTERVAL=2000 RETRO_GPT_EVAL_ITERS=50 -RETRO_GPT_TRAIN_SAMPLES=200000 RETRO_GPT_LR_DECAY_SAMPLES=175000 RETRO_GPT_LR_WARMUP_SAMPLES=10000 -RETRO_GPT_SEQ_LENGTH=512 +RETRO_GPT_SEQ_LENGTH=2048 RETRO_GPT_GLOBAL_BATCH_SIZE=256 RETRO_GPT_CHUNK_LENGTH=64 ######## Query. ######## -RETRO_QUERY_NUM_NEIGHBORS_QUERY=200 RETRO_QUERY_NUM_NEIGHBORS_SAVE=20 +RETRO_QUERY_NUM_NEIGHBORS_QUERY=200 +RETRO_QUERY_NUM_NEIGHBORS_SAVE=20 RETRO_QUERY_EF_SEARCH=32 RETRO_QUERY_NPROBE=4096 @@ -61,12 +69,12 @@ ARGS=" \ --global-batch-size ${RETRO_GPT_GLOBAL_BATCH_SIZE} \ --seq-length 512 \ --max-position-embeddings 512 \ - --load \ + --load ${RETRO_PROJECT_DIR}/checkpoints/bert \ --exit-on-missing-checkpoint \ --no-load-optim \ - --data-path ${RETRO_GPT_DATA_PATH} \ + --data-path [null] \ --tokenizer-type BertWordPieceLowerCase \ - --vocab-file \ + --vocab-file ${RETRO_PROJECT_DIR}/tokenizer/bert-large-uncased-vocab.txt \ --split ${RETRO_GPT_SPLIT} \ --distributed-backend nccl \ --lr 0.0001 \ @@ -79,23 +87,21 @@ ARGS=" \ --clip-grad 1.0 \ --eval-interval ${RETRO_GPT_EVAL_INTERVAL} \ --eval-iters ${RETRO_GPT_EVAL_ITERS} \ - --fp16 \ - --DDP-impl local \ - --dataloader-type ${RETRO_GPT_DATALOADER_TYPE} \ + --bf16 \ --no-data-sharding \ --no-gradient-accumulation-fusion \ --no-async-tensor-model-parallel-allreduce \ --bert-embedder-type megatron \ --output-bert-embeddings \ \ - --retro-workdir ${RETRO_WORKDIR} \ + --retro-project-dir ${RETRO_PROJECT_DIR} \ --retro-tasks ${RETRO_TASKS} \ - --retro-return-doc-ids \ - --retro-bert-vocab-file \ + --retro-bert-vocab-file tokenizer/bert-large-uncased-vocab.txt \ --retro-bert-tokenizer-type BertWordPieceLowerCase \ + \ --retro-gpt-seed ${RETRO_GPT_SEED} \ --retro-gpt-tokenizer-type GPTSentencePieceTokenizer \ - --retro-gpt-tokenizer-model \ + --retro-gpt-tokenizer-model /path/to/tokenizer/model \ --retro-gpt-seq-length ${RETRO_GPT_SEQ_LENGTH} \ --retro-gpt-chunk-length ${RETRO_GPT_CHUNK_LENGTH} \ --retro-gpt-global-batch-size ${RETRO_GPT_GLOBAL_BATCH_SIZE} \ @@ -103,12 +109,15 @@ ARGS=" \ --retro-gpt-eval-iters ${RETRO_GPT_EVAL_ITERS} \ --retro-gpt-split ${RETRO_GPT_SPLIT} \ --retro-gpt-data-path ${RETRO_GPT_DATA_PATH} \ + --retro-gpt-train-samples ${RETRO_GPT_TRAIN_SAMPLES} \ + \ --retro-index-str ${RETRO_INDEX_STR} \ --retro-index-ntrain ${RETRO_INDEX_NTRAIN} \ --retro-index-train-load-fraction ${RETRO_INDEX_TRAIN_LOAD_FRACTION} \ --retro-index-add-load-fraction ${RETRO_INDEX_ADD_LOAD_FRACTION} \ - --retro-index-no-delete-training-embeddings \ - --retro-index-no-delete-added-codes \ + --no-retro-index-delete-training-embeddings \ + --no-retro-index-delete-added-codes \ + \ --retro-query-num-neighbors-query ${RETRO_QUERY_NUM_NEIGHBORS_QUERY} \ --retro-query-num-neighbors-save ${RETRO_QUERY_NUM_NEIGHBORS_SAVE} \ --retro-query-ef-search ${RETRO_QUERY_EF_SEARCH} \ @@ -127,7 +136,7 @@ CMD="\ --node_rank ${NODE_RANK} \ --master_addr ${MASTER_ADDR} \ --master_port 6000 \ - tools/retro/main.py ${ARGS} \ + tools/retro/preprocess_data.py ${ARGS} \ " echo "~~~~~~~~~~~~~~~~~~~~~~~~~~" echo "CMD = '$CMD'." 
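For reference, here is a minimal sketch of driving the renamed `examples/retro/preprocess_data.sh` through the full preprocessing sequence named in its comments above (db-build, index-train, index-add, query-neighbors). It is illustrative only and assumes the placeholders inside the script (`RETRO_PROJECT_DIR`, `DATA_BLEND`, the BERT checkpoint/vocab paths, and the GPT tokenizer model) have already been filled in:

```
#!/bin/bash
# Sketch: run the Retro preprocessing tasks in order, one invocation per task.
# Assumes RETRO_PROJECT_DIR, DATA_BLEND, and the checkpoint/tokenizer paths
# inside examples/retro/preprocess_data.sh have already been filled in.
set -e

for task in db-build index-train index-add query-neighbors; do
    echo "Running Retro preprocessing task: ${task}"
    bash examples/retro/preprocess_data.sh "${task}"
done
```

The tasks are run in the listed order because each stage builds on the output of the previous one (chunk database, then index training and population, then neighbor queries for the pretraining samples).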
diff --git a/examples/retro/train_retro_2b_distributed.sh b/examples/retro/train_retro_2b_distributed.sh new file mode 100644 index 0000000000..c8276b56f4 --- /dev/null +++ b/examples/retro/train_retro_2b_distributed.sh @@ -0,0 +1,98 @@ +#!/bin/bash + +# Runs the "307M" parameter Retro model. + +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NUM_NODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES)) + +CHECKPOINT_PATH=$1 # +TENSORBOARD_LOGS_PATH=$2 # + +DISTRIBUTED_ARGS=( + --nproc_per_node $GPUS_PER_NODE + --nnodes $NUM_NODES + --master_addr $MASTER_ADDR + --master_port $MASTER_PORT +) + +######## GPT or Retro? ######## + +# 0 : GPT. +# 1 : Retro + +ADD_RETRIEVER=1 + +######## Megatron, Retro dirs. ######## + +RETRO_PROJECT_DIR="" + +######## Model, training args. ######## + +# ** Note: --seq-length auto loaded from Retro project dir. +RETRO_MODEL_ARGS=( + --num-layers 32 + --hidden-size 2048 + --num-attention-heads 32 +) + +# ** Note: --data-path, --tokenizer-type, and --tokenizer-model auto loaded from Retro project dir. +DATA_ARGS=( + --split 98,2,0 +) + +MODEL_PARALLEL_ARGS=( + --tensor-model-parallel-size 8 + --pipeline-model-parallel-size 1 +) + +# ** Note: --eval-interval, --eval-iters auto loaded from Retro project dir. +EVAL_AND_LOGGING_ARGS=( + --log-interval 100 + --save-interval 10000 + --eval-interval 1000 + --save $CHECKPOINT_PATH + --load $CHECKPOINT_PATH + --eval-iters 10 + --tensorboard-dir $TENSORBOARD_LOGS_PATH +) + +TRAINING_ARGS=" \ + --retro-project-dir ${RETRO_PROJECT_DIR} \ + --transformer-impl transformer_engine \ + --num-workers 8 \ + --micro-batch-size 4 \ + --lr-decay-samples 166400000 \ + --lr-warmup-samples 162761 \ + --lr 6.0e-4 \ + --min-lr 6.0e-5 \ + --lr-decay-style cosine \ + --clip-grad 1.0 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.023 \ + --log-params-norm \ + --log-num-zeros-in-grad \ + --bf16 \ + --no-data-sharding \ +" + +if [ "$ADD_RETRIEVER" = "1" ]; then + TRAINING_ARGS+=" --retro-add-retriever" +fi + +######## Command. 
######## + +torchrun ${DISTRIBUTED_ARGS[@]} pretrain_retro.py \ + ${RETRO_MODEL_ARGS[@]} \ + ${TRAINING_ARGS} \ + ${MODEL_PARALLEL_ARGS[@]} \ + ${DATA_ARGS[@]} \ + ${EVAL_AND_LOGGING_ARGS[@]} diff --git a/examples/run_simple_mcore_train_loop.py b/examples/run_simple_mcore_train_loop.py new file mode 100644 index 0000000000..d5ffffeeaf --- /dev/null +++ b/examples/run_simple_mcore_train_loop.py @@ -0,0 +1,158 @@ +import os +import torch +from torch.optim import Adam +from torch.utils.data import DataLoader +from functools import partial +from pathlib import Path + +from megatron.core import parallel_state +from megatron.core import dist_checkpointing +from megatron.core.pipeline_parallel.schedules import get_forward_backward_func +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec +from megatron.core.datasets.utils import compile_helpers +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.core.datasets.gpt_dataset import GPTDatasetConfig, MockGPTDataset +from megatron.training.tokenizer.tokenizer import _NullTokenizer + + +_SEQUENCE_LENGTH = 64 + + +def initialize_distributed(tensor_model_parallel_size=1, pipeline_model_parallel_size=1): + parallel_state.destroy_model_parallel() + + # Torch setup for distributed training + rank = int(os.environ['LOCAL_RANK']) + world_size = torch.cuda.device_count() + torch.cuda.set_device(rank) + torch.distributed.init_process_group(world_size=world_size, rank=rank) + + # Megatron core distributed training initialization + parallel_state.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size) + +def model_provider(): + """Build the model.""" + + transformer_config = TransformerConfig( + num_layers=2, + hidden_size=12, + num_attention_heads=4, + use_cpu_initialization=True, + pipeline_dtype=torch.float32, + ) + + gpt_model = GPTModel( + config=transformer_config, + transformer_layer_spec=get_gpt_layer_local_spec(), + vocab_size=100, + max_sequence_length=_SEQUENCE_LENGTH, + ) + + return gpt_model + +def get_train_data_iterator(): + if torch.distributed.is_available() and torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + compile_helpers() + torch.distributed.barrier() + else: + compile_helpers() + + config = GPTDatasetConfig( + random_seed=0, + sequence_length=_SEQUENCE_LENGTH, + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False, + tokenizer=_NullTokenizer(vocab_size=_SEQUENCE_LENGTH), + ) + + datasets = BlendedMegatronDatasetBuilder( + MockGPTDataset, [1000, None, None], lambda: True, config + ).build() + + train_dataloader = DataLoader(datasets[0], batch_size=8, shuffle=True) + + train_iterator = iter(train_dataloader) + + return train_iterator + +def forward_step_func(data_iterator, model): + + def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor): + + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + # If you have data parallel reduce loss across data parallel groups. + # If pipeline parallel, loss computation is done only in last stage. 
+ + return loss, {'lm loss': loss} + + data = next(data_iterator) + tokens = data['tokens'].to(device) + attention_mask = data['attention_mask'].to(device) + position_ids = data['position_ids'].to(device) + labels = data['labels'].to(device) + loss_mask = data['loss_mask'].to(device) + + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels) + + return output_tensor, partial(loss_func, loss_mask) + +def save_distributed_checkpoint(checkpoint_path, gpt_model): + sharded_state_dict = gpt_model.sharded_state_dict(prefix='') + dist_checkpointing.save(sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path) + +def load_distributed_checkpoint(checkpoint_path, gpt_model): + sharded_state_dict=gpt_model.sharded_state_dict(prefix='') + checkpoint = dist_checkpointing.load(sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path) + gpt_model.load_state_dict(checkpoint) + return gpt_model + +if __name__ == "__main__": + initialize_distributed(tensor_model_parallel_size=2, pipeline_model_parallel_size=1) + model_parallel_cuda_manual_seed(123) + + gpt_model = model_provider() + device = torch.device("cuda") + gpt_model.to(device) + + optim = Adam(gpt_model.parameters()) + + train_iterator = get_train_data_iterator() + + forward_backward_func = get_forward_backward_func() + + # Running the model for 5 iterations + for _ in range(5): + optim.zero_grad() + + losses_reduced = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=train_iterator, + model=gpt_model, + num_microbatches=1, + seq_length=_SEQUENCE_LENGTH, + micro_batch_size=8, + decoder_seq_length=_SEQUENCE_LENGTH, + forward_only=False) + + optim.step() + + print(f'Losses reduced : {losses_reduced}') + + # Saving the model + ckpt_path = os.getcwd() + '/ckpt' + Path(ckpt_path).mkdir(exist_ok=True) + save_distributed_checkpoint(gpt_model=gpt_model, checkpoint_path=ckpt_path) + + # Loading the model + gpt_model = load_distributed_checkpoint(gpt_model=gpt_model, checkpoint_path=ckpt_path) + gpt_model.to(device) + print('Successfully loaded the model') + diff --git a/examples/t5/README.md b/examples/t5/README.md new file mode 100644 index 0000000000..205da1db37 --- /dev/null +++ b/examples/t5/README.md @@ -0,0 +1,55 @@ +# T5 MODEL + +## Table of contents +- [1. Training Setup](#1-training-setup) +- [2. Configurations](#2-configurations) +- [3. Training Results](#3-training-results) + +## 1. Training setup + +To run the model on a Slurm-based cluster (a direct single-node launch sketch without Slurm is shown further below): +``` +PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:23.09-py3 +ACCOUNT_NAME="" +PARTITION="" +JOB_NAME="" +NUM_NODES=1 +CHECKPOINT_PATH="" # +TENSORBOARD_LOGS_PATH=""# +VOCAB_FILE="" #/bert-large-cased-vocab.txt +DATA_PATH="" #_text_document + +srun -N $NUM_NODES --container-image $PYTORCH_IMAGE --container-mounts "/path/to/data:/path/to/data,/path/to/megatron-lm:/workspace/megatron-lm" --account $ACCOUNT_NAME -N 1 -J $JOB_NAME -p $PARTITION --no-container-mount-home -c " + cd /workspace/megatron-lm + ./examples/t5/train_t5_220m_distributed.sh $CHECKPOINT_PATH $TENSORBOARD_LOGS_PATH $VOCAB_FILE $DATA_PATH" + +``` + +## 2. Configurations + +The architecture arguments below show the configuration for the T5 220M model. + +### 220M +``` + --num-layers 12 \ + --hidden-size 768 \ + --num-attention-heads 12 \ + --kv-channels 64 \ + --ffn-hidden-size 3072 \ + --encoder-seq-length 512 \ + --decoder-seq-length 128 \ + --max-position-embeddings 512 \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + +``` + + +## 3.
Training Results + +Below is the training curve for the 220M model on the Pile dataset. The training takes 4 days on 32 GPUs, with a batch size of 2048. + +After finetuning on the SQuAD dataset, the validation result is 63.44%. + +![T5 220M training curve](t5_mcore_train_curve.png)
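For a quick single-node run without Slurm, the renamed `examples/t5/train_t5_220m_distributed.sh` (see its diff below) can also be invoked directly: it launches `torchrun` itself and takes the checkpoint directory, TensorBoard directory, vocab file, and data prefix as positional arguments. A minimal sketch, in which every path is a placeholder assumption:

```
#!/bin/bash
# Sketch: direct single-node launch of the 220M T5 example (no Slurm).
# All four paths are placeholders; substitute your own locations.
CHECKPOINT_PATH=/workspace/checkpoints/t5_220m            # checkpoint save/load dir
TENSORBOARD_LOGS_PATH=/workspace/tensorboard/t5_220m      # TensorBoard event files
VOCAB_FILE=/workspace/data/bert-large-cased-vocab.txt     # BERT WordPiece vocab
DATA_PATH=/workspace/data/t5_dataset_text_document        # preprocessed data prefix

bash examples/t5/train_t5_220m_distributed.sh \
    "$CHECKPOINT_PATH" "$TENSORBOARD_LOGS_PATH" "$VOCAB_FILE" "$DATA_PATH"
```

The script itself assumes 8 GPUs on a single node (`GPUS_PER_NODE=8`, `NUM_NODES=1`); adjust those variables inside the script for other hardware.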

diff --git a/examples/t5/t5_mcore_train_curve.png b/examples/t5/t5_mcore_train_curve.png new file mode 100644 index 0000000000..de1aaa8582 Binary files /dev/null and b/examples/t5/t5_mcore_train_curve.png differ diff --git a/examples/pretrain_t5_distributed.sh b/examples/t5/train_t5_220m_distributed.sh old mode 100644 new mode 100755 similarity index 56% rename from examples/pretrain_t5_distributed.sh rename to examples/t5/train_t5_220m_distributed.sh index 42698e01af..62e6f9db4b --- a/examples/pretrain_t5_distributed.sh +++ b/examples/t5/train_t5_220m_distributed.sh @@ -1,29 +1,33 @@ #!/bin/bash +# Runs the "220M" parameter model + export CUDA_DEVICE_MAX_CONNECTIONS=1 GPUS_PER_NODE=8 # Change for multinode config MASTER_ADDR=localhost MASTER_PORT=6000 -NNODES=1 +NUM_NODES=1 NODE_RANK=0 -WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) +WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES)) -CHECKPOINT_PATH= -VOCAB_FILE=/t5-vocab.txt -DATA_PATH=_text_sentence +CHECKPOINT_PATH=$1 # +TENSORBOARD_DIR=$2 # +VOCAB_FILE=$3 #/bert-large-cased-vocab.txt +DATA_PATH=$4 #_text_document DISTRIBUTED_ARGS=" --nproc_per_node $GPUS_PER_NODE \ - --nnodes $NNODES \ + --nnodes $NUM_NODES \ --node_rank $NODE_RANK \ --master_addr $MASTER_ADDR \ --master_port $MASTER_PORT " T5_ARGS=" - --num-layers 12 \ + --encoder-num-layers 12 \ + --decoder-num-layers 12 \ --hidden-size 768 \ --num-attention-heads 12 \ --kv-channels 64 \ @@ -31,8 +35,8 @@ T5_ARGS=" --encoder-seq-length 512 \ --decoder-seq-length 128 \ --max-position-embeddings 512 \ - --micro-batch-size 16 \ - --global-batch-size 128 \ + --micro-batch-size 64 \ + --global-batch-size 512 \ --lr 0.0001 \ --train-iters 1000000 \ --lr-decay-iters 1000000 \ @@ -41,19 +45,26 @@ T5_ARGS=" --weight-decay 1e-2 \ --lr-warmup-fraction .01 \ --clip-grad 1.0 \ - --fp16 \ - --vocab-extra-ids 100 + --bf16 \ + --vocab-extra-ids 100 \ + --init-method-std 0.015 \ + --transformer-impl transformer_engine \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --attention-backend auto \ " DATA_ARGS=" --data-path $DATA_PATH \ --vocab-file $VOCAB_FILE \ - --split 949,50,1 + --tokenizer-type BertWordPieceCase \ + --split 99982,9,9 \ " OUTPUT_ARGS=" --log-interval 100 \ - --save-interval 10000 \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + --save-interval 500 \ --eval-interval 1000 \ --eval-iters 10 " @@ -64,4 +75,4 @@ torchrun $DISTRIBUTED_ARGS pretrain_t5.py \ $OUTPUT_ARGS \ --distributed-backend nccl \ --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH + --load $CHECKPOINT_PATH \ diff --git a/images/Achieved_petaFLOPs.png b/images/Achieved_petaFLOPs.png deleted file mode 100644 index 3431099f3f..0000000000 Binary files a/images/Achieved_petaFLOPs.png and /dev/null differ diff --git a/images/cases_april2021.png b/images/cases_april2021.png deleted file mode 100644 index 8a6d9e9f8b..0000000000 Binary files a/images/cases_april2021.png and /dev/null differ diff --git a/images/model_table.png b/images/model_table.png new file mode 100644 index 0000000000..f126c2fcfb Binary files /dev/null and b/images/model_table.png differ diff --git a/images/strong_scaling.png b/images/strong_scaling.png new file mode 100644 index 0000000000..d8337c347e Binary files /dev/null and b/images/strong_scaling.png differ diff --git a/images/weak_scaling.png b/images/weak_scaling.png new file mode 100644 index 0000000000..59c3cec6c6 Binary files /dev/null and b/images/weak_scaling.png differ diff --git a/megatron/arguments.py b/megatron/arguments.py deleted file mode 100644 index 9192e12c7a..0000000000 --- 
a/megatron/arguments.py +++ /dev/null @@ -1,1366 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -"""Megatron arguments.""" - -import argparse -import dataclasses -import json -import os -import torch -import types - -import torch.nn.functional as F -from megatron.global_vars import set_retro_args, get_retro_args -from tools.retro.utils import get_args_path as get_retro_args_path - -from megatron.core.transformer import TransformerConfig - - -def parse_args(extra_args_provider=None, ignore_unknown_args=False): - """Parse all arguments.""" - parser = argparse.ArgumentParser(description='Megatron-LM Arguments', - allow_abbrev=False) - - # Standard arguments. - parser = _add_network_size_args(parser) - parser = _add_regularization_args(parser) - parser = _add_training_args(parser) - parser = _add_initialization_args(parser) - parser = _add_learning_rate_args(parser) - parser = _add_checkpointing_args(parser) - parser = _add_mixed_precision_args(parser) - parser = _add_distributed_args(parser) - parser = _add_validation_args(parser) - parser = _add_data_args(parser) - parser = _add_autoresume_args(parser) - parser = _add_biencoder_args(parser) - parser = _add_vision_args(parser) - parser = _add_logging_args(parser) - parser = _add_inference_args(parser) - parser = _add_transformer_engine_args(parser) - parser = _add_retro_args(parser) - parser = _add_experimental_args(parser) - - # Custom arguments. - if extra_args_provider is not None: - parser = extra_args_provider(parser) - - # Parse. - if ignore_unknown_args: - args, _ = parser.parse_known_args() - else: - args = parser.parse_args() - - # Args from environment - args.rank = int(os.getenv('RANK', '0')) - args.world_size = int(os.getenv("WORLD_SIZE", '1')) - - return args - -def validate_args(args, defaults={}): - # Tensor model parallel size. - args.tensor_model_parallel_size = min( - args.tensor_model_parallel_size, args.world_size) - assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\ - ' ({}) is not divisible by tensor model parallel size ({})'.format( - args.world_size, args.tensor_model_parallel_size) - # Pipeline model parallel size. - args.pipeline_model_parallel_size = min( - args.pipeline_model_parallel_size, - (args.world_size // args.tensor_model_parallel_size)) - args.transformer_pipeline_model_parallel_size = ( - args.pipeline_model_parallel_size - 1 - if args.standalone_embedding_stage else - args.pipeline_model_parallel_size - ) - # Checks. 
- model_parallel_size = args.pipeline_model_parallel_size * \ - args.tensor_model_parallel_size - assert args.world_size % (model_parallel_size * args.context_parallel_size) == 0, \ - 'world size ({}) is not divisible by tensor parallel size ({}) times ' \ - 'pipeline parallel size ({}) times context parallel size ({})'.format( - args.world_size, args.tensor_model_parallel_size, - args.pipeline_model_parallel_size, args.context_parallel_size) - args.data_parallel_size = args.world_size // (model_parallel_size * args.context_parallel_size) - if args.rank == 0: - print('using world size: {}, data-parallel-size: {}, ' - 'context-parallel-size: {} ' - 'tensor-model-parallel size: {}, ' - 'pipeline-model-parallel size: {} '.format( - args.world_size, args.data_parallel_size, - args.context_parallel_size, - args.tensor_model_parallel_size, - args.pipeline_model_parallel_size), flush=True) - if args.pipeline_model_parallel_size > 1: - if args.pipeline_model_parallel_split_rank is not None: - assert args.pipeline_model_parallel_split_rank < \ - args.pipeline_model_parallel_size, 'split rank needs'\ - ' to be less than pipeline model parallel size ({})'.format( - args.pipeline_model_parallel_size) - - if args.tp_comm_overlap: - assert args.sequence_parallel == True, 'Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled' - - - # Deprecated arguments - assert args.batch_size is None, '--batch-size argument is no longer ' \ - 'valid, use --micro-batch-size instead' - del args.batch_size - assert args.warmup is None, '--warmup argument is no longer valid, use ' \ - '--lr-warmup-fraction instead' - del args.warmup - assert args.model_parallel_size is None, '--model-parallel-size is no ' \ - 'longer valid, use --tensor-model-parallel-size instead' - del args.model_parallel_size - - if args.checkpoint_activations: - if args.rank == 0: - print('--checkpoint-activations is no longer valid, use --recompute-activations, ' - 'or, for more control, --recompute-granularity and --recompute-method.') - exit() - del args.checkpoint_activations - - if args.recompute_activations: - args.recompute_granularity = 'selective' - del args.recompute_activations - - # Set input defaults. - for key in defaults: - # For default to be valid, it should not be provided in the - # arguments that are passed to the program. We check this by - # ensuring the arg is set to None. - if getattr(args, key, None) is not None: - if args.rank == 0: - print('WARNING: overriding default arguments for {key}:{v} \ - with {key}:{v2}'.format(key=key, v=defaults[key], - v2=getattr(args, key)), - flush=True) - else: - setattr(args, key, defaults[key]) - - # Batch size. 
- assert args.micro_batch_size is not None - assert args.micro_batch_size > 0 - if args.global_batch_size is None: - args.global_batch_size = args.micro_batch_size * args.data_parallel_size - if args.rank == 0: - print('setting global batch size to {}'.format( - args.global_batch_size), flush=True) - assert args.global_batch_size > 0 - if args.num_layers_per_virtual_pipeline_stage is not None: - assert args.pipeline_model_parallel_size > 2, \ - 'pipeline-model-parallel size should be greater than 2 with ' \ - 'interleaved schedule' - assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \ - 'number of layers should be divisible by the pipeline parallel size' - num_layers_per_pipeline_stage = args.num_layers // args.transformer_pipeline_model_parallel_size - assert num_layers_per_pipeline_stage % args.num_layers_per_virtual_pipeline_stage == 0, \ - 'number of layers per pipeline stage must be divisible number of layers per virtual pipeline stage' - args.virtual_pipeline_model_parallel_size = num_layers_per_pipeline_stage // \ - args.num_layers_per_virtual_pipeline_stage - else: - args.virtual_pipeline_model_parallel_size = None - # Overlap P2P communication is disabled if not using the interleaved schedule. - args.overlap_p2p_comm = False - if args.rank == 0: - print('WARNING: Setting args.overlap_p2p_comm to False since non-interleaved ' - 'schedule does not support overlapping p2p communication') - - # Parameters dtype. - args.params_dtype = torch.float - if args.fp16: - assert not args.bf16 - args.params_dtype = torch.half - if args.bf16: - assert not args.fp16 - args.params_dtype = torch.bfloat16 - # bfloat16 requires gradient accumulation and all-reduce to - # be done in fp32. - if not args.accumulate_allreduce_grads_in_fp32: - args.accumulate_allreduce_grads_in_fp32 = True - if args.rank == 0: - print('accumulate and all-reduce gradients in fp32 for ' - 'bfloat16 data type.', flush=True) - - if args.rank == 0: - print('using {} for parameters ...'.format(args.params_dtype), - flush=True) - - if args.dataloader_type is None: - args.dataloader_type = 'single' - - # Consumed tokens. - args.consumed_train_samples = 0 - args.consumed_valid_samples = 0 - - # Support for variable sequence lengths across batches/microbatches. - # set it if the dataloader supports generation of variable sequence lengths - # across batches/microbatches. Due to additional communication overhead - # during pipeline parallelism, it should not be set if sequence length - # is constant during training. - args.variable_seq_lengths = False - - # Iteration-based training. - if args.train_iters: - # If we use iteration-based training, make sure the - # sample-based options are off. - assert args.train_samples is None, \ - 'expected iteration-based training' - assert args.lr_decay_samples is None, \ - 'expected iteration-based learning rate decay' - assert args.lr_warmup_samples == 0, \ - 'expected iteration-based learning rate warmup' - assert args.rampup_batch_size is None, \ - 'expected no batch-size rampup for iteration-based training' - if args.lr_warmup_fraction is not None: - assert args.lr_warmup_iters == 0, \ - 'can only specify one of lr-warmup-fraction and lr-warmup-iters' - - # Sample-based training. - if args.train_samples: - # If we use sample-based training, make sure the - # iteration-based options are off. 
- assert args.train_iters is None, \ - 'expected sample-based training' - assert args.lr_decay_iters is None, \ - 'expected sample-based learning rate decay' - assert args.lr_warmup_iters == 0, \ - 'expected sample-based learnig rate warmup' - if args.lr_warmup_fraction is not None: - assert args.lr_warmup_samples == 0, \ - 'can only specify one of lr-warmup-fraction ' \ - 'and lr-warmup-samples' - - if args.num_layers is not None: - assert args.encoder_num_layers is None, \ - 'cannot have both num-layers and encoder-num-layers specified' - args.encoder_num_layers = args.num_layers - else: - assert args.encoder_num_layers is not None, \ - 'either num-layers or encoder-num-layers should be specified' - args.num_layers = args.encoder_num_layers - - # Check required arguments. - required_args = ['num_layers', 'hidden_size', 'num_attention_heads', - 'max_position_embeddings'] - for req_arg in required_args: - _check_arg_is_not_none(args, req_arg) - - # Checks. - if args.ffn_hidden_size is None: - if args.swiglu: - # reduce the dimnesion for MLP since projections happens on - # two linear layers. this keeps the number of paramters in - # the same ballpark as the counterpart with 4*h size - # we keep it a multiple of 64, which means the actual tensor size - # will be a multiple of 64 / tp_size - args.ffn_hidden_size = int((4 * args.hidden_size * 2 / 3) / 64) * 64 - else: - args.ffn_hidden_size = 4 * args.hidden_size - - if args.kv_channels is None: - assert args.hidden_size % args.num_attention_heads == 0 - args.kv_channels = args.hidden_size // args.num_attention_heads - - if args.seq_length is not None: - assert args.encoder_seq_length is None - args.encoder_seq_length = args.seq_length - else: - assert args.encoder_seq_length is not None - args.seq_length = args.encoder_seq_length - - if args.seq_length is not None: - assert args.max_position_embeddings >= args.seq_length - if args.decoder_seq_length is not None: - assert args.max_position_embeddings >= args.decoder_seq_length - if args.lr is not None: - assert args.min_lr <= args.lr - if args.save is not None: - assert args.save_interval is not None - # Mixed precision checks. - if args.fp16_lm_cross_entropy: - assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.' - if args.fp32_residual_connection: - assert args.fp16 or args.bf16, \ - 'residual connection in fp32 only supported when using fp16 or bf16.' - - if args.weight_decay_incr_style == 'constant': - assert args.start_weight_decay is None - assert args.end_weight_decay is None - args.start_weight_decay = args.weight_decay - args.end_weight_decay = args.weight_decay - else: - assert args.start_weight_decay is not None - assert args.end_weight_decay is not None - - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) - # Persistent fused layer norm. - if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11): - args.no_persist_layer_norm = True - if args.rank == 0: - print('Persistent fused layer norm kernel is supported from ' - 'pytorch v1.11 (nvidia pytorch container paired with v1.11). ' - 'Defaulting to no_persist_layer_norm=True') - - # Activation recomputing. 
- if args.distribute_saved_activations: - assert args.tensor_model_parallel_size > 1, 'can distribute ' \ - 'recomputed activations only across tensor model ' \ - 'parallel groups' - assert args.recompute_granularity == 'full', \ - 'distributed recompute activations is only '\ - 'application to full recompute granularity' - assert args.recompute_method is not None, \ - 'for distributed recompute activations to work you '\ - 'need to use a recompute method ' - assert (TORCH_MAJOR, TORCH_MINOR) >= (1, 10), \ - 'distributed recompute activations are supported for pytorch ' \ - 'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \ - 'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR) - - if args.recompute_granularity == 'selective': - assert args.recompute_method is None, \ - 'recompute method is not yet supported for ' \ - 'selective recomputing granularity' - - # disable sequence parallelism when tp=1 - # to avoid change in numerics when - # sequence_parallelism is enabled. - if args.tensor_model_parallel_size == 1: - args.sequence_parallel = False - - # disable async_tensor_model_parallel_allreduce when - # model parallel memory optimization is enabled - if args.sequence_parallel: - args.async_tensor_model_parallel_allreduce = False - - if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1": - if args.sequence_parallel: - raise RuntimeError( - "Using sequence parallelism requires setting the environment variable " - "CUDA_DEVICE_MAX_CONNECTIONS to 1") - if args.async_tensor_model_parallel_allreduce: - raise RuntimeError( - "Using async gradient all reduce requires setting the environment " - "variable CUDA_DEVICE_MAX_CONNECTIONS to 1") - - # Disable bias gelu fusion if we are disabling bias altogether - if not args.add_bias_linear: - args.bias_gelu_fusion = False - - # Retro checks. - if args.retro_add_retriever: - - # Sequence parallelism unsupported. - assert not args.sequence_parallel, \ - "retro currently does not support sequence parallelism." - - # Pipeline parallelism unsupported. - assert args.pipeline_model_parallel_size == 1, \ - "retro currently does not support pipeline parallelism." - - # Load retro args. - retro_args_path = get_retro_args_path(args.retro_workdir) - assert os.path.exists(retro_args_path), "retro workdir missing args.json" - with open(retro_args_path) as f: - retro_args = types.SimpleNamespace(**json.load(f)) - retro_args.retro_return_doc_ids = args.retro_return_doc_ids - retro_args.retro_gpt_retrieved_length = \ - args.retro_num_retrieved_chunks * \ - retro_args.retro_gpt_chunk_length - set_retro_args(retro_args) - - # Legacy RoPE arguments - if args.use_rotary_position_embeddings: - args.position_embedding_type = 'rope' - - # Would just need to add 'NoPE' as a position_embedding_type to support this, but for now - # don't allow it to keep things simple - if not args.add_position_embedding and args.position_embedding_type != 'rope': - raise RuntimeError('--no-position-embedding is deprecated, use --position-embedding-type') - - # MoE Spec check - if args.num_experts is not None: - assert args.model_spec is None, "Model Spec must be None when using MoEs" - - # Expert parallelism check - if args.expert_model_parallel_size > 1: - assert args.num_experts is not None, "num_experts must be non None to use expert model parallelism" - assert args.num_experts % args.expert_model_parallel_size == 0, \ - "Number of experts should be a multiple of expert model parallel_size." 
- if args.tensor_model_parallel_size > 1: - assert args.sequence_parallel, \ - "When using expert parallelism and tensor parallelism, sequence parallelism must be used." - - # Print arguments. - _print_args("arguments", args) - retro_args = get_retro_args() - if retro_args and args != retro_args: - _print_args("retro arguments", types.SimpleNamespace(**{k:v for k,v in vars(retro_args).items() if k.startswith("retro")}, rank=args.rank)) - - return args - - -def _print_args(title, args): - """Print arguments.""" - if args.rank == 0: - print(f'------------------------ {title} ------------------------', - flush=True) - str_list = [] - for arg in vars(args): - dots = '.' * (48 - len(arg)) - str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg))) - for arg in sorted(str_list, key=lambda x: x.lower()): - print(arg, flush=True) - print(f'-------------------- end of {title} ---------------------', - flush=True) - - -def _check_arg_is_not_none(args, arg): - assert getattr(args, arg) is not None, '{} argument is None'.format(arg) - -def core_transformer_config_from_args(args): - - # Translate args to core transformer configuration - kw_args = {} - for f in dataclasses.fields(TransformerConfig): - if hasattr(args, f.name): - kw_args[f.name] = getattr(args, f.name) - kw_args['persist_layer_norm'] = not args.no_persist_layer_norm - kw_args['layernorm_zero_centered_gamma'] = args.apply_layernorm_1p - kw_args['layernorm_epsilon'] = args.norm_epsilon - kw_args['deallocate_pipeline_outputs'] = True - kw_args['pipeline_dtype'] = args.params_dtype - kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm - kw_args['num_moe_experts'] = args.num_experts - if args.swiglu: - kw_args['activation_func'] = F.silu - kw_args['gated_linear_unit'] = True - kw_args['bias_gelu_fusion'] = False - if args.init_method_xavier_uniform: - kw_args['init_method'] = torch.nn.init.xavier_uniform_ - kw_args['scaled_init_method'] = torch.nn.init.xavier_uniform_ - if args.group_query_attention: - kw_args['num_query_groups'] = args.num_query_groups - else: - kw_args['num_query_groups'] = None - - return TransformerConfig(**kw_args) - -def _add_transformer_engine_args(parser): - group = parser.add_argument_group(title='Transformer-Engine') - - group.add_argument('--fp8-format', default=None, - choices=['e4m3', 'hybrid'], - help='Which fp8 format scheme to use for FP8 tensors in the forward and backward pass', - dest='fp8') - group.add_argument('--fp8-margin', type=int, default=0, - help='Scaling margin for fp8', - dest='fp8_margin') - group.add_argument('--fp8-interval', type=int, default=1, - help='Scaling update interval for fp8', - dest='fp8_interval') - group.add_argument('--fp8-amax-history-len', type=int, default=1, - help='Number of steps for which amax history is recorded per tensor', - dest='fp8_amax_history_len') - group.add_argument('--fp8-amax-compute-algo', default='most_recent', - choices=['most_recent', 'max'], - help='Algorithm for computing amax from history', - dest='fp8_amax_compute_algo') - group.add_argument('--no-fp8-wgrad', action='store_false', - help='Execute wgrad in higher precision even for FP8 runs', - dest='fp8_wgrad') - group.add_argument('--transformer-impl', default='local', - choices=['local', 'transformer_engine'], - help='Which Transformer implementation to use.') - - return parser - -def _add_inference_args(parser): - group = parser.add_argument_group(title='inference') - - group.add_argument('--inference-batch-times-seqlen-threshold', - type=int, default=512, - help='During inference, if 
batch-size times ' - 'sequence-length is smaller than this threshold ' - 'then we will not use pipelining, otherwise we will.') - group.add_argument('--max-tokens-to-oom', - type=int, default=12000, - help='Maximum number of tokens during inference' - 'tokens here is # in prompt + # to generate' - 'Allows us to throw an error before OOM crashes server') - group.add_argument('--output-bert-embeddings', action='store_true', - help='Output Bert embeddings (via mean pooling) from ' - 'model, rather than its binary head output or entire ' - 'hidden batch.') - group.add_argument('--bert-embedder-type', default="megatron", - choices=["megatron", "huggingface"], - help='Select either Megatron or Huggingface as the ' - 'Bert embedder.') - - return parser - - -def _add_retro_args(parser): - group = parser.add_argument_group(title='retro') - - group.add_argument('--retro-workdir', default=None, - help='Retro working directory, which contains the ' - 'preprocessed data for for pretraining. This directory ' - 'is built during preprocessing (see ' - 'tools/retro/README.md), and contains subdirectories ' - 'for the chunk database and pretraining neighbors.') - group.add_argument('--retro-add-retriever', - action='store_true', default=False, - help='Add a retriever to the transformer, for use in ' - 'pretraining a Retro model.') - group.add_argument('--retro-cyclic-train-iters', type=int, default=None, - help='Set number of training iterations for cyclic ' - 'Retro training.') - group.add_argument('--retro-encoder-layers', type=int, default=2, - help='Number of layers to use for the retrieval ' - 'encoder.') - group.add_argument('--retro-encoder-hidden-dropout', - type=float, default=0.1, help='Hidden dropout for ' - 'retrieval encoder.') - group.add_argument('--retro-encoder-attention-dropout', - type=float, default=0.1, help='Attention dropout for ' - 'retrieval encoder.') - group.add_argument("--retro-num-neighbors", type=int, default=2, - help='Number of neighbors to retrieve during ' - 'pretraining.') - group.add_argument("--retro-num-retrieved-chunks", type=int, default=2, - help='Number of chunks to retrieve from the retrieval ' - 'database.') - group.add_argument("--retro-return-doc-ids", action="store_true", - help="Turn this on when preprocessing retro data.") - - # Enforce argument naming convention. - for action in group._group_actions: - prefix = action.dest.split("_")[0] - assert prefix == "retro", \ - "Retro args must be prefixed with '--retro-*', for consistent " \ - "styling. Please fix '%s'." % ", ".join(action.option_strings) - - return parser - - -def _add_network_size_args(parser): - group = parser.add_argument_group(title='network size') - - group.add_argument('--num-layers', type=int, default=None, - help='Number of transformer layers.') - group.add_argument('--encoder-num-layers', type=int, default=None, - help='Number of encoder transformer layers.') - group.add_argument('--decoder-num-layers', type=int, default=None, - help='Number of decoder transformer layers.') - group.add_argument('--hidden-size', type=int, default=None, - help='Tansformer hidden size.') - group.add_argument('--ffn-hidden-size', type=int, default=None, - help='Transformer Feed-Forward Network hidden size. ' - 'This is set to 4*hidden-size if not provided') - group.add_argument('--num-attention-heads', type=int, default=None, - help='Number of transformer attention heads.') - group.add_argument('--kv-channels', type=int, default=None, - help='Projection weights dimension in multi-head ' - 'attention. 
This is set to ' - ' args.hidden_size // args.num_attention_heads ' - 'if not provided.') - group.add_argument('--group-query-attention', action='store_true', - help='Use group-query attention.') - group.add_argument('--num-query-groups', type=int, default=1) - - group.add_argument('--max-position-embeddings', type=int, default=None, - help='Maximum number of position embeddings to use. ' - 'This is the size of position embedding.') - group.add_argument('--position-embedding-type', type=str, default='learned_absolute', - choices=['learned_absolute', 'rope'], - help='Position embedding type.') - group.add_argument('--use-rotary-position-embeddings', action='store_true', - help='Use rotary positional embeddings or not. ' - 'Deprecated: use --position-embedding-type') - group.add_argument('--rotary-percent', type=float, default=1.0, - help='Percent of rotary dimension to use, default 100%%') - group.add_argument('--rotary-seq-len-interpolation-factor', type=int, default=None, - help='Sequence length interpolation factor for rotary embeddings.') - group.add_argument('--no-position-embedding', - action='store_false', - help='Disable position embedding. Deprecated: use --position-embedding-type', - dest='add_position_embedding') - group.add_argument('--make-vocab-size-divisible-by', type=int, default=128, - help='Pad the vocab size to be divisible by this value.' - 'This is added for computational efficieny reasons.') - group.add_argument('--normalization', default='LayerNorm', - choices=['LayerNorm', 'RMSNorm'], - help='Which normalization technique to use.') - group.add_argument('--norm-epsilon', type=float, default=1e-5, - help='Epsilon for layer norm and RMS norm.') - group.add_argument('--apply-layernorm-1p', action='store_true', - help='Adjust LayerNorm weights such that they are centered ' - 'around zero. This improves numerical stability.') - group.add_argument('--apply-residual-connection-post-layernorm', - action='store_true', - help='If set, use original BERT residula connection ' - 'ordering.') - group.add_argument('--openai-gelu', action='store_true', - help='Use OpenAIs GeLU implementation. This option' - 'should not be used unless for backward compatibility' - 'reasons.') - group.add_argument('--squared-relu', action='store_true', - help='Use squared relu activation instead of default gelu') - group.add_argument('--swiglu', action='store_true', - help='Use gated linear units and SiLU activation instead of default gelu') - group.add_argument('--onnx-safe', type=bool, required=False, - help='Use workarounds for known problems with ' - 'Torch ONNX exporter') - group.add_argument('--bert-no-binary-head', action='store_false', - help='Disable BERT binary head.', - dest='bert_binary_head') - group.add_argument('--num-experts', type=int, default=None, - help='Number of Experts in Switch Transformer (None means no Switch)') - group.add_argument('--untie-embeddings-and-output-weights', action='store_true', - help='Untie embeddings and output weights.'), - return parser - - -def _add_logging_args(parser): - group = parser.add_argument_group(title='logging') - - group.add_argument('--log-params-norm', action='store_true', - help='If set, calculate and log parameters norm.') - group.add_argument('--log-num-zeros-in-grad', action='store_true', - help='If set, calculate and log the number of zeros in gradient.') - group.add_argument('--timing-log-level', type=int, - default=0, choices=range(0,3), - help='Granularity level to measure and report timing. 
' - ' 0: report only iteration time and make sure timing ' - ' does not introduce extra overhead.' - ' 1: report timing for operations that are executed ' - ' very limited times (basically once) during ' - ' each iteration (such as gradient all-reduce) ' - ' 2: report timing for operations that migh be ' - ' executed numerous times during each iteration. ' - 'Note that setting the level to 1 or 2 might ' - 'cause increase in iteration time.') - group.add_argument('--no-barrier-with-level-1-timing', action='store_false', - help='If not set, use barrier with level 1 time ' - 'measurements. Note that this is up to the user ' - 'to make sure calling barrier with their timers ' - 'will not result in hangs. This can happen if for ' - 'example the user adds a level 1 timer that is not ' - 'called by all ranks.', - dest='barrier_with_L1_time') - group.add_argument('--timing-log-option', type=str, default='minmax', - choices=['max', 'minmax', 'all'], - help='Options for logging timing:' - ' max: report the max timing across all ranks' - ' minmax: report min and max timings across all ranks' - ' all: report timings of all ranks.') - group.add_argument('--tensorboard-log-interval', type=int, default=1, - help='Report to tensorboard interval.') - group.add_argument('--tensorboard-queue-size', type=int, default=1000, - help='Size of the tensorboard queue for pending events ' - 'and summaries before one of the ‘add’ calls forces a ' - 'flush to disk.') - group.add_argument('--log-timers-to-tensorboard', action='store_true', - help='If set, write timers to tensorboard.') - group.add_argument('--log-batch-size-to-tensorboard', action='store_true', - help='If set, write batch-size to tensorboard.') - group.add_argument('--no-log-learnig-rate-to-tensorboard', - action='store_false', - help='Disable learning rate logging to tensorboard.', - dest='log_learning_rate_to_tensorboard') - group.add_argument('--no-log-loss-scale-to-tensorboard', - action='store_false', - help='Disable loss-scale logging to tensorboard.', - dest='log_loss_scale_to_tensorboard') - group.add_argument('--log-validation-ppl-to-tensorboard', - action='store_true', - help='If set, write validation perplexity to ' - 'tensorboard.') - group.add_argument('--log-memory-to-tensorboard', - action='store_true', - help='Enable memory logging to tensorboard.') - group.add_argument('--log-world-size-to-tensorboard', - action='store_true', - help='Enable world size logging to tensorboard.') - group.add_argument('--wandb-project', type=str, default='', - help='The wandb project name. 
Ignore wandb by default.') - group.add_argument('--wandb-exp-name', type=str, default='', - help='The wandb experiment name.') - group.add_argument('--wandb-save-dir', type=str, default='', - help='Path to save the wandb results locally.') - return parser - - -def _add_regularization_args(parser): - group = parser.add_argument_group(title='regularization') - - group.add_argument('--attention-dropout', type=float, default=0.1, - help='Post attention dropout probability.') - group.add_argument('--hidden-dropout', type=float, default=0.1, - help='Dropout probability for hidden state transformer.') - group.add_argument('--weight-decay', type=float, default=0.01, - help='Weight decay coefficient for L2 regularization.') - group.add_argument('--start-weight-decay', type=float, - help='Initial weight decay coefficient for L2 regularization.') - group.add_argument('--end-weight-decay', type=float, - help='End of run weight decay coefficient for L2 regularization.') - group.add_argument('--weight-decay-incr-style', type=str, default='constant', - choices=['constant', 'linear', 'cosine'], - help='Weight decay increment function.') - group.add_argument('--clip-grad', type=float, default=1.0, - help='Gradient clipping based on global L2 norm.') - group.add_argument('--adam-beta1', type=float, default=0.9, - help='First coefficient for computing running averages ' - 'of gradient and its square') - group.add_argument('--adam-beta2', type=float, default=0.999, - help='Second coefficient for computing running averages ' - 'of gradient and its square') - group.add_argument('--adam-eps', type=float, default=1e-08, - help='Term added to the denominator to improve' - 'numerical stability') - group.add_argument('--sgd-momentum', type=float, default=0.9, - help='Momentum factor for sgd') - return parser - - -def _add_training_args(parser): - group = parser.add_argument_group(title='training') - - group.add_argument('--micro-batch-size', type=int, default=None, - help='Batch size per model instance (local batch size). ' - 'Global batch size is local batch size times data ' - 'parallel size times number of micro batches.') - group.add_argument('--batch-size', type=int, default=None, - help='Old batch size parameter, do not use. ' - 'Use --micro-batch-size instead') - group.add_argument('--global-batch-size', type=int, default=None, - help='Training batch size. If set, it should be a ' - 'multiple of micro-batch-size times data-parallel-size. ' - 'If this value is None, then ' - 'use micro-batch-size * data-parallel-size as the ' - 'global batch size. This choice will result in 1 for ' - 'number of micro-batches.') - group.add_argument('--rampup-batch-size', nargs='*', default=None, - help='Batch size ramp up with the following values:' - ' --rampup-batch-size ' - ' ' - ' ' - 'For example:' - ' --rampup-batch-size 16 8 300000 \ ' - ' --global-batch-size 1024' - 'will start with global batch size 16 and over ' - ' (1024 - 16) / 8 = 126 intervals will increase' - 'the batch size linearly to 1024. In each interval' - 'we will use approximately 300000 / 126 = 2380 samples.') - group.add_argument('--recompute-activations', action='store_true', - help='recompute activation to allow for training ' - 'with larger models, sequences, and batch sizes.') - group.add_argument('--recompute-granularity', type=str, default=None, - choices=['full', 'selective'], - help='Checkpoint activations to allow for training ' - 'with larger models, sequences, and batch sizes. 
' - 'It is supported at two granularities 1) full: ' - 'whole transformer layer is recomputed, ' - '2) selective: core attention part of the transformer ' - 'layer is recomputed.') - group.add_argument('--no-check-for-nan-in-loss-and-grad', action='store_false', - help='Check for NaNs in loss and grad', - dest='check_for_nan_in_loss_and_grad') - group.add_argument('--distribute-saved-activations', - action='store_true', - help='If set, distribute recomputed activations ' - 'across model parallel group.') - group.add_argument('--recompute-method', type=str, default=None, - choices=['uniform', 'block'], - help='1) uniform: uniformly divide the total number of ' - 'Transformer layers and recompute the input activation of ' - 'each divided chunk at specified granularity, ' - '2) recompute the input activations of only a set number of ' - 'individual Transformer layers per pipeline stage and do the ' - 'rest without any recomputing at specified granularity' - 'default) do not apply activations recompute to any layers') - group.add_argument('--recompute-num-layers', type=int, default=None, - help='1) uniform: the number of Transformer layers in each ' - 'uniformly divided recompute unit, ' - '2) block: the number of individual Transformer layers ' - 'to recompute within each pipeline stage.') - group.add_argument('--profile', action='store_true', - help='Enable nsys profiling. When using this option, nsys ' - 'options should be specified in commandline. An example ' - 'nsys commandline is `nsys profile -s none -t nvtx,cuda ' - '-o --force-overwrite true ' - '--capture-range=cudaProfilerApi ' - '--capture-range-end=stop`.') - group.add_argument('--profile-step-start', type=int, default=10, - help='Gloable step to start profiling.') - group.add_argument('--profile-step-end', type=int, default=12, - help='Gloable step to stop profiling.') - group.add_argument('--profile-ranks', nargs='+', type=int, default=[0], - help='Global ranks to profile.') - group.add_argument('--tp-comm-overlap', action='store_true', help = 'Enables the ' - ' overlap of Tensor parallel communication and GEMM kernels.') - group.add_argument('--tp-comm-overlap-cfg', type=str, default=None, - help = 'Config file when tp_comm_overlap is enabled.') - group.add_argument('--disable-tp-comm-split-ag', action='store_false', - help = 'Disables the All-Gather overlap with fprop GEMM.', - dest='tp_comm_split_ag') - group.add_argument('--disable-tp-comm-split-rs', action='store_false', - help = 'Disables the Reduce-Scatter overlap with fprop GEMM.', - dest='tp_comm_split_rs') - group.add_argument('--disable-tp-comm-bulk-dgrad', action='store_false', - help = 'Disables the All-Gather overlap with bprop activation gradient GEMM.', - dest='tp_comm_bulk_dgrad') - group.add_argument('--disable-tp-comm-bulk-wgrad', action='store_false', - help = 'Disables the Reduce-Scatter overlap with bprop weight gradient GEMM.', - dest='tp_comm_bulk_wgrad') - - - # deprecated - group.add_argument('--checkpoint-activations', action='store_true', - help='Checkpoint activation to allow for training ' - 'with larger models, sequences, and batch sizes.') - group.add_argument('--train-iters', type=int, default=None, - help='Total number of iterations to train over all ' - 'training runs. Note that either train-iters or ' - 'train-samples should be provided.') - group.add_argument('--train-samples', type=int, default=None, - help='Total number of samples to train over all ' - 'training runs. 
Note that either train-iters or ' - 'train-samples should be provided.') - group.add_argument('--log-interval', type=int, default=100, - help='Report loss and timing interval.') - group.add_argument('--exit-interval', type=int, default=None, - help='Exit the program after the iteration is divisible ' - 'by this value.') - group.add_argument('--exit-duration-in-mins', type=int, default=None, - help='Exit the program after this many minutes.') - group.add_argument('--exit-signal-handler', action='store_true', - help='Dynamically save the checkpoint and shutdown the ' - 'training if SIGTERM is received') - group.add_argument('--tensorboard-dir', type=str, default=None, - help='Write TensorBoard logs to this directory.') - group.add_argument('--no-masked-softmax-fusion', - action='store_false', - help='Disable fusion of query_key_value scaling, ' - 'masking, and softmax.', - dest='masked_softmax_fusion') - group.add_argument('--no-bias-gelu-fusion', action='store_false', - help='Disable bias and gelu fusion.', - dest='bias_gelu_fusion') - group.add_argument('--no-bias-dropout-fusion', action='store_false', - help='Disable bias and dropout fusion.', - dest='bias_dropout_fusion') - group.add_argument('--use-flash-attn', action='store_true', - help='use FlashAttention implementation of attention. ' - 'https://arxiv.org/abs/2205.14135') - group.add_argument('--disable-bias-linear', action='store_false', - help='Disable bias in the linear layers', - dest='add_bias_linear') - group.add_argument('--optimizer', type=str, default='adam', - choices=['adam', 'sgd'], - help='Optimizer function') - group.add_argument('--dataloader-type', type=str, default=None, - choices=['single', 'cyclic'], - help='Single pass vs multiple pass data loader') - group.add_argument('--no-async-tensor-model-parallel-allreduce', - action='store_false', - help='Disable asynchronous execution of ' - 'tensor-model-parallel all-reduce with weight ' - 'gradient compuation of a column-linear layer.', - dest='async_tensor_model_parallel_allreduce') - group.add_argument('--no-persist-layer-norm', action='store_true', - help='Disable using persistent fused layer norm kernel. ' - 'This kernel supports only a set of hidden sizes. Please ' - 'check persist_ln_hidden_sizes if your hidden ' - 'size is supported.') - group.add_argument('--sequence-parallel', action='store_true', - help='Enable sequence parallel optimization.') - group.add_argument('--no-gradient-accumulation-fusion', - action='store_false', - help='Disable fusing gradient accumulation to weight ' - 'gradient computation of linear layers', - dest='gradient_accumulation_fusion') - group.add_argument('--use-mcore-models', action='store_true', - help='Use the implementation from megatron core', - dest='use_mcore_models') - group.add_argument('--expert-parallel', action='store_true', - help='Enable expert parallel optimization.') - group.add_argument('--manual-gc', action='store_true', - help='Disable the threshold-based default garbage ' - 'collector and trigger the garbage collection manually. ' - 'Manual garbage collection helps to align the timing of ' - 'the collection across ranks which mitigates the impact ' - 'of CPU-associated jitters. When the manual gc is enabled, ' - 'garbage collection is performed only at the start and the ' - 'end of the validation routine by default.') - group.add_argument('--manual-gc-interval', type=int, default=0, - help='Training step interval to trigger manual garbage ' - 'collection. 
When the value is set to 0, garbage ' - 'collection is not triggered between training steps.') - group.add_argument('--no-manual-gc-eval', action='store_false', - help='When using manual garbage collection, disable ' - 'garbage collection at the start and the end of each ' - 'evaluation run.', dest='manual_gc_eval') - - return parser - - -def _add_initialization_args(parser): - group = parser.add_argument_group(title='initialization') - - group.add_argument('--seed', type=int, default=1234, - help='Random seed used for python, numpy, ' - 'pytorch, and cuda.') - group.add_argument('--data-parallel-random-init', action='store_true', - help='Enable random initialization of params ' - 'across data parallel ranks') - group.add_argument('--init-method-std', type=float, default=0.02, - help='Standard deviation of the zero mean normal ' - 'distribution used for weight initialization.') - group.add_argument('--init-method-xavier-uniform', action='store_true', - help='Enable Xavier uniform parameter initialization') - - return parser - - -def _add_learning_rate_args(parser): - group = parser.add_argument_group(title='learning rate') - - group.add_argument('--lr', type=float, default=None, - help='Initial learning rate. Depending on decay style ' - 'and initial warmup, the learing rate at each ' - 'iteration would be different.') - group.add_argument('--lr-decay-style', type=str, default='linear', - choices=['constant', 'linear', 'cosine', 'inverse-square-root'], - help='Learning rate decay function.') - group.add_argument('--lr-decay-iters', type=int, default=None, - help='number of iterations to decay learning rate over,' - ' If None defaults to `--train-iters`') - group.add_argument('--lr-decay-samples', type=int, default=None, - help='number of samples to decay learning rate over,' - ' If None defaults to `--train-samples`') - group.add_argument('--lr-warmup-fraction', type=float, default=None, - help='fraction of lr-warmup-(iters/samples) to use ' - 'for warmup (as a float)') - group.add_argument('--lr-warmup-iters', type=int, default=0, - help='number of iterations to linearly warmup ' - 'learning rate over.') - group.add_argument('--lr-warmup-samples', type=int, default=0, - help='number of samples to linearly warmup ' - 'learning rate over.') - group.add_argument('--lr-warmup-init', type=float, default=0.0, - help='Initial value for learning rate warmup. The ' - 'scheduler starts warmup from this value.') - group.add_argument('--warmup', type=int, default=None, - help='Old lr warmup argument, do not use. Use one of the' - '--lr-warmup-* arguments above') - group.add_argument('--min-lr', type=float, default=0.0, - help='Minumum value for learning rate. The scheduler' - 'clip values below this threshold.') - group.add_argument('--override-opt_param-scheduler', action='store_true', - help='Reset the values of the scheduler (learning rate,' - 'warmup iterations, minimum learning rate, maximum ' - 'number of iterations, and decay style from input ' - 'arguments and ignore values from checkpoints. 
Note' - 'that all the above values will be reset.') - group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true', - help='Use checkpoint to set the values of the scheduler ' - '(learning rate, warmup iterations, minimum learning ' - 'rate, maximum number of iterations, and decay style ' - 'from checkpoint and ignore input arguments.') - - return parser - - -def _add_checkpointing_args(parser): - group = parser.add_argument_group(title='checkpointing') - - group.add_argument('--save', type=str, default=None, - help='Output directory to save checkpoints to.') - group.add_argument('--save-interval', type=int, default=None, - help='Number of iterations between checkpoint saves.') - group.add_argument('--no-save-optim', action='store_true', default=None, - help='Do not save current optimizer.') - group.add_argument('--no-save-rng', action='store_true', default=None, - help='Do not save current rng state.') - group.add_argument('--load', type=str, default=None, - help='Directory containing a model checkpoint.') - group.add_argument('--no-load-optim', action='store_true', default=None, - help='Do not load optimizer when loading checkpoint.') - group.add_argument('--no-load-rng', action='store_true', default=None, - help='Do not load rng state when loading checkpoint.') - group.add_argument('--finetune', action='store_true', - help='Load model for finetuning. Do not load optimizer ' - 'or rng state from checkpoint and set iteration to 0. ' - 'Assumed when loading a release checkpoint.') - group.add_argument('--no-initialization', action='store_false', - help='Do not perform initialization when building model, ' - 'can reduce startup time when definitely loading from a ' - 'checkpoint', - dest='perform_initialization') - group.add_argument('--use-checkpoint-args', action='store_true', - help='Override any command line arguments with arguments ' - 'from the checkpoint') - group.add_argument('--exit-on-missing-checkpoint', action='store_true', - help="If '--load' is set, but checkpoint is not found " - "(e.g., path typo), then exit instead of random " - "initialization.") - - return parser - - -def _add_mixed_precision_args(parser): - group = parser.add_argument_group(title='mixed precision') - - group.add_argument('--fp16', action='store_true', - help='Run model in fp16 mode.') - group.add_argument('--bf16', action='store_true', - help='Run model in bfloat16 mode.') - group.add_argument('--loss-scale', type=float, default=None, - help='Static loss scaling, positive power of 2 ' - 'values can improve fp16 convergence. If None, dynamic' - 'loss scaling is used.') - group.add_argument('--initial-loss-scale', type=float, default=2**32, - help='Initial loss-scale for dynamic loss scaling.') - group.add_argument('--min-loss-scale', type=float, default=1.0, - help='Minimum loss scale for dynamic loss scale.') - group.add_argument('--loss-scale-window', type=float, default=1000, - help='Window over which to raise/lower dynamic scale.') - group.add_argument('--hysteresis', type=int, default=2, - help='hysteresis for dynamic loss scaling') - group.add_argument('--fp32-residual-connection', action='store_true', - help='Move residual connections to fp32.') - group.add_argument('--no-query-key-layer-scaling', action='store_false', - help='Do not scale Q * K^T by 1 / layer-number.', - dest='apply_query_key_layer_scaling') - group.add_argument('--attention-softmax-in-fp32', action='store_true', - help='Run attention masking and softmax in fp32. 
' - 'This flag is ignored unless ' - '--no-query-key-layer-scaling is specified.') - group.add_argument('--accumulate-allreduce-grads-in-fp32', - action='store_true', - help='Gradient accumulation and all-reduce in fp32.') - group.add_argument('--fp16-lm-cross-entropy', action='store_true', - help='Move the cross entropy unreduced loss calculation' - 'for lm head to fp16.') - - return parser - - -def _add_distributed_args(parser): - group = parser.add_argument_group(title='distributed') - - group.add_argument('--tensor-model-parallel-size', type=int, default=1, - help='Degree of tensor model parallelism.') - group.add_argument('--pipeline-model-parallel-size', type=int, default=1, - help='Degree of pipeline model parallelism.') - group.add_argument('--pipeline-model-parallel-split-rank', - type=int, default=None, - help='Rank where encoder and decoder should be split.') - group.add_argument('--model-parallel-size', type=int, default=None, - help='Old model parallel argument, do not use. Use ' - '--tensor-model-parallel-size instead.') - group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None, - help='Number of layers per virtual pipeline stage') - group.add_argument('--no-overlap-p2p-communication', action='store_false', - help='overlap pipeline parallel communication with forward and backward chunks', - dest='overlap_p2p_comm') - group.add_argument('--distributed-backend', default='nccl', - choices=['nccl', 'gloo'], - help='Which backend to use for distributed training.') - group.add_argument('--distributed-timeout-minutes', type=int, default=10, - help='Timeout minutes for torch.distributed.') - group.add_argument('--overlap-grad-reduce', action='store_true', - default=False, help='If set, overlap DDP grad reduce.') - group.add_argument('--no-delay-grad-reduce', action='store_false', - help='If not set, delay grad reduction in all but first PP stage.', - dest='delay_grad_reduce') - group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false', - help='If not set, use scatter/gather to optimize communication of tensors in pipeline.', - dest='scatter_gather_tensors_in_pipeline') - group.add_argument('--use-ring-exchange-p2p', action='store_true', - default=False, help='If set, use custom-built ring exchange ' - 'for p2p communications. Note that this option will require ' - 'a custom built image that support ring-exchange p2p.') - group.add_argument('--local_rank', type=int, default=None, - help='local rank passed from distributed launcher.') - group.add_argument('--lazy-mpu-init', type=bool, required=False, - help='If set to True, initialize_megatron() ' - 'skips DDP initialization and returns function to ' - 'complete it instead.Also turns on ' - '--use-cpu-initialization flag. This is for ' - 'external DDP manager.' ) - group.add_argument('--use-cpu-initialization', action='store_true', - default=None, help='If set, affine parallel weights ' - 'initialization uses CPU' ) - group.add_argument('--empty-unused-memory-level', default=0, type=int, - choices=[0, 1, 2], - help='Call torch.cuda.empty_cache() each iteration ' - '(training and eval), to reduce fragmentation.' - '0=off, 1=moderate, 2=aggressive.') - group.add_argument('--standalone-embedding-stage', action='store_true', - default=False, help='If set, *input* embedding layer ' - 'is placed on its own pipeline stage, without any ' - 'transformer layers. 
(For T5, this flag currently only ' - 'affects the encoder embedding.)') - group.add_argument('--use-distributed-optimizer', action='store_true', - help='Use distributed optimizer.') - group.add_argument('--expert-model-parallel-size', type=int, default=1, - help='Degree of expert model parallelism.') - group.add_argument('--context-parallel-size', type=int, default=1, - help='Degree of context parallelism.') - return parser - - -def _add_validation_args(parser): - group = parser.add_argument_group(title='validation') - - group.add_argument('--eval-iters', type=int, default=100, - help='Number of iterations to run for evaluation' - 'validation/test for.') - group.add_argument('--eval-interval', type=int, default=1000, - help='Interval between running evaluation on ' - 'validation set.') - group.add_argument('--skip-train', action='store_true', - default=False, help='If set, bypass the training loop, ' - 'optionally do evaluation for validation/test, and exit.') - - return parser - - -def _add_data_args(parser): - group = parser.add_argument_group(title='data and dataloader') - - group.add_argument('--data-path', nargs='*', default=None, - help='Path to the training dataset. Accepted format:' - '1) a single data path, 2) multiple datasets in the' - 'form: dataset1-weight dataset1-path dataset2-weight ' - 'dataset2-path ... It is used with --split when a ' - 'single dataset used for all three: train, valid ' - 'and test. It is exclusive to the other ' - '--*-data-path args') - group.add_argument('--split', type=str, default='969, 30, 1', - help='Comma-separated list of proportions for training,' - ' validation, and test split. For example the split ' - '`90,5,5` will use 90%% of data for training, 5%% for ' - 'validation and 5%% for test.') - group.add_argument('--train-data-path', nargs='*', default=None, - help='Path to the training dataset. Accepted format:' - '1) a single data path, 2) multiple datasets in the' - 'form: dataset1-weight dataset1-path dataset2-weight ' - 'dataset2-path ...') - group.add_argument('--valid-data-path', nargs='*', default=None, - help='Path to the validation dataset. Accepted format:' - '1) a single data path, 2) multiple datasets in the' - 'form: dataset1-weight dataset1-path dataset2-weight ' - 'dataset2-path ...') - group.add_argument('--test-data-path', nargs='*', default=None, - help='Path to the test dataset. Accepted format:' - '1) a single data path, 2) multiple datasets in the' - 'form: dataset1-weight dataset1-path dataset2-weight ' - 'dataset2-path ...') - group.add_argument('--data-cache-path', default=None, - help='Path to a directory to hold cached index files.') - - group.add_argument('--vocab-size', type=int, default=None, - help='Size of vocab before EOD or padding.') - group.add_argument('--vocab-file', type=str, default=None, - help='Path to the vocab file.') - group.add_argument('--merge-file', type=str, default=None, - help='Path to the BPE merge file.') - group.add_argument('--vocab-extra-ids', type=int, default=0, - help='Number of additional vocabulary tokens. ' - 'They are used for span masking in the T5 model') - group.add_argument('--seq-length', type=int, default=None, - help='Maximum sequence length to process.') - group.add_argument('--encoder-seq-length', type=int, default=None, - help='Maximum encoder sequence length to process.' 
- 'This should be exclusive of --seq-length') - group.add_argument('--decoder-seq-length', type=int, default=None, - help="Maximum decoder sequence length to process.") - group.add_argument('--retriever-seq-length', type=int, default=256, - help='Maximum sequence length for the biencoder model ' - 'for retriever') - group.add_argument('--sample-rate', type=float, default=1.0, - help='sample rate for training data. Supposed to be 0 ' - ' < sample_rate < 1') - group.add_argument('--mask-prob', type=float, default=0.15, - help='Probability of replacing a token with mask.') - group.add_argument('--short-seq-prob', type=float, default=0.1, - help='Probability of producing a short sequence.') - group.add_argument('--num-workers', type=int, default=2, - help="Dataloader number of workers.") - group.add_argument('--tokenizer-type', type=str, - default=None, - choices=['BertWordPieceLowerCase', - 'BertWordPieceCase', - 'GPT2BPETokenizer', - 'SentencePieceTokenizer', - 'GPTSentencePieceTokenizer', - 'Llama2Tokenizer', - 'NullTokenizer'], - help='What type of tokenizer to use.') - group.add_argument('--tokenizer-model', type=str, default=None, - help='Sentencepiece tokenizer model.') - group.add_argument('--reset-position-ids', action='store_true', - help='Reset posistion ids after end-of-document token.') - group.add_argument('--reset-attention-mask', action='store_true', - help='Reset self attention maske after ' - 'end-of-document token.') - group.add_argument('--eod-mask-loss', action='store_true', - help='Mask loss for the end of document tokens.') - - return parser - - -def _add_autoresume_args(parser): - group = parser.add_argument_group(title='autoresume') - - group.add_argument('--adlr-autoresume', action='store_true', - help='Enable autoresume on adlr cluster.') - group.add_argument('--adlr-autoresume-interval', type=int, default=1000, - help='Intervals over which check for autoresume' - 'termination signal') - - return parser - - -def _add_biencoder_args(parser): - group = parser.add_argument_group(title='biencoder') - - # network size - group.add_argument('--ict-head-size', type=int, default=None, - help='Size of block embeddings to be used in ICT and ' - 'REALM (paper default: 128)') - group.add_argument('--biencoder-projection-dim', type=int, default=0, - help='Size of projection head used in biencoder (paper' - ' default: 128)') - group.add_argument('--biencoder-shared-query-context-model', action='store_true', - help='Whether to share the parameters of the query ' - 'and context models or not') - - # checkpointing - group.add_argument('--ict-load', type=str, default=None, - help='Directory containing an ICTBertModel checkpoint') - group.add_argument('--bert-load', type=str, default=None, - help='Directory containing an BertModel checkpoint ' - '(needed to start ICT and REALM)') - - # data - group.add_argument('--titles-data-path', type=str, default=None, - help='Path to titles dataset used for ICT') - group.add_argument('--query-in-block-prob', type=float, default=0.1, - help='Probability of keeping query in block for ' - 'ICT dataset') - group.add_argument('--use-one-sent-docs', action='store_true', - help='Whether to use one sentence documents in ICT') - group.add_argument('--evidence-data-path', type=str, default=None, - help='Path to Wikipedia Evidence frm DPR paper') - - # training - group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int, - default=[], help="Which top-k accuracies to report " - "(e.g. 
'1 5 20')") - group.add_argument('--retriever-score-scaling', action='store_true', - help='Whether to scale retriever scores by inverse ' - 'square root of hidden size') - - # faiss index - group.add_argument('--block-data-path', type=str, default=None, - help='Where to save/load BlockData to/from') - group.add_argument('--embedding-path', type=str, default=None, - help='Where to save/load Open-Retrieval Embedding' - ' data to/from') - - # indexer - group.add_argument('--indexer-batch-size', type=int, default=128, - help='How large of batches to use when doing indexing ' - 'jobs') - group.add_argument('--indexer-log-interval', type=int, default=1000, - help='After how many batches should the indexer ' - 'report progress') - return parser - - -def _add_vision_args(parser): - group = parser.add_argument_group(title="vision") - - # general vision arguements - group.add_argument('--num-classes', type=int, default=1000, - help='num of classes in vision classificaiton task') - group.add_argument('--img-h', type=int, default=224, - help='Image height for vision classification task') - group.add_argument('--img-w', type=int, default=224, - help='Image height for vision classification task') - group.add_argument('--num-channels', type=int, default=3, - help='Number of channels in input image data') - group.add_argument('--patch-dim', type=int, default=16, - help='patch dimension') - group.add_argument('--classes-fraction', type=float, default=1.0, - help='training with fraction of classes.') - group.add_argument('--data-per-class-fraction', type=float, default=1.0, - help='training with fraction of data per class.') - group.add_argument('--no-data-sharding', action='store_false', - help='Disable data sharding.', - dest='data_sharding') - group.add_argument('--head-lr-mult', type=float, default=1.0, - help='learning rate multiplier for head during finetuning') - - # pretraining type and backbone selection` - group.add_argument('--vision-pretraining', action='store_true', - help='flag to indicate vision pretraining') - group.add_argument('--vision-pretraining-type', type=str, default='classify', - choices=['classify', 'inpaint', 'dino'], - help='pretraining objectives') - group.add_argument('--vision-backbone-type', type=str, default='vit', - choices=['vit', 'mit', 'swin'], - help='backbone types types') - group.add_argument('--swin-backbone-type', type=str, default='tiny', - choices=['tiny', 'base', 'h3'], - help='pretraining objectives') - - # inpainting arguments - group.add_argument('--mask-type', type=str, default='random', - choices=['random', 'row'], - help='mask types') - group.add_argument('--mask-factor', type=float, default=1.0, - help='mask size scaling parameter') - - # dino arguments - group.add_argument('--iter-per-epoch', type=int, default=1250, - help='iterations per epoch') - group.add_argument('--dino-local-img-size', type=int, default=96, - help='Image size for vision classification task') - group.add_argument('--dino-local-crops-number', type=int, default=10, - help='Number of local crops') - group.add_argument('--dino-head-hidden-size', type=int, default=2048, - help='Hidden dimension size in dino head') - group.add_argument('--dino-bottleneck-size', type=int, default=256, - help='Bottle neck dimension in dino head ') - group.add_argument('--dino-freeze-last-layer', type=float, default=1, - help='Freezing last layer weights') - group.add_argument('--dino-norm-last-layer', action='store_true', - help='Disable Norm in last layer.') - 
group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04, - help='warump teacher temperature') - group.add_argument('--dino-teacher-temp', type=float, default=0.07, - help='teacher temperature') - group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30, - help='warmup teacher temperaure epochs') - - return parser - -def _add_experimental_args(parser): - group = parser.add_argument_group(title='experimental') - - group.add_argument('--model-spec', - type=str, default=None, nargs=2, - help='Specify the pair ' - 'that returns a spec to customize the transformer ' - 'layer implementation. For more details, check the' - '`transformer_layer.py` file that details the use ' - 'of spec based customization.') - return parser diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py deleted file mode 100644 index 2be766e384..0000000000 --- a/megatron/checkpointing.py +++ /dev/null @@ -1,714 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -"""Input/output checkpointing.""" - -import os -import random -import sys -import numpy as np - -import torch - -from megatron import update_num_microbatches -from megatron.core import mpu, tensor_parallel -from .global_vars import get_args -from .utils import (unwrap_model, - print_rank_0) - - -_CHECKPOINT_VERSION = None - - -def set_checkpoint_version(value): - global _CHECKPOINT_VERSION - if _CHECKPOINT_VERSION is not None: - assert _CHECKPOINT_VERSION == value, \ - "checkpoint versions do not match" - _CHECKPOINT_VERSION = value - - -def get_checkpoint_version(): - global _CHECKPOINT_VERSION - return _CHECKPOINT_VERSION - - -def check_checkpoint_args(checkpoint_args): - """Ensure fixed arguments for a model are the same for the input - arguments and the one retrieved from checkpoint.""" - args = get_args() - - def _compare(arg_name, old_arg_name=None, default=None): - if old_arg_name is not None: - ckpt_arg_name = old_arg_name - else: - ckpt_arg_name = arg_name - if default is not None: - checkpoint_value = getattr(checkpoint_args, ckpt_arg_name, default) - else: - checkpoint_value = getattr(checkpoint_args, ckpt_arg_name) - args_value = getattr(args, arg_name) - error_message = '{} value from checkpoint ({}) is not equal to the ' \ - 'input argument value ({}).'.format( - arg_name, checkpoint_value, args_value) - assert checkpoint_value == args_value, error_message - - _compare('num_layers') - _compare('hidden_size') - _compare('num_attention_heads') - _compare('add_position_embedding', default=True) - if args.vocab_file: - _compare('max_position_embeddings') - _compare('make_vocab_size_divisible_by') - _compare('padded_vocab_size') - _compare('tokenizer_type') - if args.data_parallel_random_init: - _compare('data_parallel_random_init') - if get_checkpoint_version() < 3.0: - _compare('tensor_model_parallel_size', - old_arg_name='model_parallel_size') - if get_checkpoint_version() >= 3.0: - _compare('tensor_model_parallel_size') - _compare('pipeline_model_parallel_size') - - -def ensure_directory_exists(filename): - """Build filename's path if it does not already exists.""" - dirname = os.path.dirname(filename) - os.makedirs(dirname, exist_ok = True) - - -def get_checkpoint_name(checkpoints_path, iteration, release=False, - pipeline_parallel=None, - tensor_rank=None, pipeline_rank=None, - expert_parallel=None, expert_rank=None): - """Determine the directory name for this rank's checkpoint.""" - if release: - directory = 'release' - else: - directory = 'iter_{:07d}'.format(iteration) - - # 
Use both the tensor and pipeline MP rank. - if pipeline_parallel is None: - pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1) - if tensor_rank is None: - tensor_rank = mpu.get_tensor_model_parallel_rank() - if pipeline_rank is None: - pipeline_rank = mpu.get_pipeline_model_parallel_rank() - if expert_parallel is None: - expert_parallel = (mpu.get_expert_model_parallel_world_size() > 1) - if expert_rank is None: - expert_rank = mpu.get_expert_model_parallel_rank() - - # Use both the tensor and pipeline MP rank. If using the distributed - # optimizer, then the optimizer's path must additionally include the - # data parallel rank. - if not pipeline_parallel: - common_path = os.path.join(checkpoints_path, directory, - f'mp_rank_{tensor_rank:02d}') - else: - common_path = os.path.join(checkpoints_path, directory, - f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}') - - if expert_parallel: - common_path = common_path + f'_{expert_rank:03d}' - - return os.path.join(common_path, "model_optim_rng.pt") - - -def get_distributed_optimizer_checkpoint_name(model_checkpoint_name): - return os.path.join(os.path.dirname(model_checkpoint_name), - "distrib_optim.pt") - - -def find_checkpoint_rank_0(checkpoints_path, iteration, release=False): - """Finds the checkpoint for rank 0 without knowing if we are using - pipeline parallelism/expert parallelism or not. - - Since the checkpoint naming scheme changes if pipeline or expert - parallelism is present, we need to look for both naming schemes if - we don't know if the checkpoint has pipeline or expert parallelism. - """ - - # Look for checkpoint with no pipelining and no expert parallelism - filename = get_checkpoint_name(checkpoints_path, iteration, release, - pipeline_parallel=False, - tensor_rank=0, pipeline_rank=0, - expert_parallel=False, expert_rank=0) - if os.path.isfile(filename): - return filename - - # Look for checkpoint with no pipelining and expert parallelism - filename = get_checkpoint_name(checkpoints_path, iteration, release, - pipeline_parallel=False, - tensor_rank=0, pipeline_rank=0, - expert_parallel=True, expert_rank=0) - if os.path.isfile(filename): - return filename - - # Look for checkpoint with pipelining and no expert parallelism - filename = get_checkpoint_name(checkpoints_path, iteration, release, - pipeline_parallel=True, - tensor_rank=0, pipeline_rank=0, - expert_parallel=False, expert_rank=0) - if os.path.isfile(filename): - return filename - - # Look for checkpoint with pipelining and expert parallelism - filename = get_checkpoint_name(checkpoints_path, iteration, release, - pipeline_parallel=True, - tensor_rank=0, pipeline_rank=0, - expert_parallel=True, expert_rank=0) - if os.path.isfile(filename): - return filename - - return None, None - - -def get_checkpoint_tracker_filename(checkpoints_path): - - """Tracker file rescords the latest chckpoint during - training to restart from.""" - return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt') - - -def read_metadata(tracker_filename): - # Read the tracker file and either set the iteration or - # mark it as a release checkpoint. - iteration = 0 - release = False - with open(tracker_filename, 'r') as f: - metastring = f.read().strip() - try: - iteration = int(metastring) - except ValueError: - release = metastring == 'release' - if not release: - print_rank_0('ERROR: Invalid metadata file {}. 
Exiting'.format( - tracker_filename)) - sys.exit() - assert iteration > 0 or release, 'error parsing metadata file {}'.format( - tracker_filename) - - # Get the max iteration retrieved across the ranks. - if torch.distributed.is_initialized(): - iters_cuda = torch.cuda.LongTensor([iteration]) - torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX) - max_iter = iters_cuda[0].item() - - # We should now have all the same iteration. - # If not, print a warning and chose the maximum - # iteration across all ranks. - if iteration != max_iter: - rank = torch.distributed.get_rank() - print('WARNING: on rank {} found iteration {} in the ' - 'metadata while max iteration across the ranks ' - 'is {}, replacing it with max iteration.'.format( - rank, iteration, max_iter), flush=True) - else: - # When loading a checkpoint outside of training (for example, - # when editing it), we might not have torch distributed - # initialized, in this case, just assume we have the latest - max_iter = iteration - return max_iter, release - - -def get_rng_state(): - """ collect rng state across data parallel ranks """ - args = get_args() - rng_state = { - 'random_rng_state': random.getstate(), - 'np_rng_state': np.random.get_state(), - 'torch_rng_state': torch.get_rng_state(), - 'cuda_rng_state': torch.cuda.get_rng_state(), - 'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()} - - rng_state_list = None - if torch.distributed.is_initialized() and \ - mpu.get_data_parallel_world_size() > 1 and \ - args.data_parallel_random_init: - rng_state_list = \ - [None for i in range(mpu.get_data_parallel_world_size())] - torch.distributed.all_gather_object( - rng_state_list, - rng_state, - group=mpu.get_data_parallel_group()) - else: - rng_state_list = [rng_state] - - return rng_state_list - - -def save_checkpoint(iteration, model, optimizer, opt_param_scheduler): - """Save a model checkpoint.""" - args = get_args() - - # Only rank zero of the data parallel writes to the disk. - model = unwrap_model(model) - - print_rank_0('saving checkpoint at iteration {:7d} to {}'.format( - iteration, args.save)) - - # Collect rng state across data parallel ranks. - rng_state = get_rng_state() - - # Checkpoint name. - checkpoint_name = get_checkpoint_name(args.save, iteration) - - # Save distributed optimizer's custom parameter state. - if args.use_distributed_optimizer and not args.no_save_optim and optimizer is not None: - optim_checkpoint_name = \ - get_distributed_optimizer_checkpoint_name(checkpoint_name) - ensure_directory_exists(optim_checkpoint_name) - optimizer.save_parameter_state(optim_checkpoint_name) - - # Collect args, model, RNG. - if not torch.distributed.is_initialized() \ - or mpu.get_data_modulo_expert_parallel_rank() == 0: - - # Arguments, iteration, and model. - state_dict = {} - state_dict['args'] = args - state_dict['checkpoint_version'] = 3.0 - state_dict['iteration'] = iteration - if len(model) == 1: - state_dict['model'] = model[0].state_dict_for_save_checkpoint() - else: - for i in range(len(model)): - mpu.set_virtual_pipeline_model_parallel_rank(i) - state_dict['model%d' % i] = \ - model[i].state_dict_for_save_checkpoint() - - # Optimizer stuff. - if not args.no_save_optim: - if optimizer is not None: - state_dict['optimizer'] = optimizer.state_dict() - if opt_param_scheduler is not None: - state_dict['opt_param_scheduler'] = \ - opt_param_scheduler.state_dict() - - # RNG states. - if not args.no_save_rng: - state_dict["rng_state"] = rng_state - - # Save. 
- ensure_directory_exists(checkpoint_name) - torch.save(state_dict, checkpoint_name) - - # Wait so everyone is done (necessary) - if torch.distributed.is_initialized(): - torch.distributed.barrier() - - print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}' \ - .format(iteration, args.save)) - - # And update the latest iteration - if not torch.distributed.is_initialized() \ - or torch.distributed.get_rank() == 0: - tracker_filename = get_checkpoint_tracker_filename(args.save) - with open(tracker_filename, 'w') as f: - f.write(str(iteration)) - - # Wait so everyone is done (not necessary) - if torch.distributed.is_initialized(): - torch.distributed.barrier() - - -def _transpose_first_dim(t, num_splits, num_splits_first, model): - input_shape = t.size() - # We use a self_attention module but the values extracted aren't - # specific to self attention so should work for cross attention as well - while hasattr(model, 'module'): - model = model.module - attention_module = model.language_model.encoder.layers[0].self_attention - hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head - num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition - if num_splits_first: - """[num_splits * np * hn, h] - -->(view) [num_splits, np, hn, h] - -->(tranpose) [np, num_splits, hn, h] - -->(view) [np * num_splits * hn, h] """ - - intermediate_shape = \ - (num_splits, num_attention_heads_per_partition, - hidden_size_per_attention_head) + input_shape[1:] - - t = t.view(*intermediate_shape) - t = t.transpose(0, 1).contiguous() - else: - """[np * hn * num_splits, h] - -->(view) [np, hn, num_splits, h] - -->(tranpose) [np, num_splits, hn, h] - -->(view) [np * num_splits * hn, h] """ - - intermediate_shape = \ - (num_attention_heads_per_partition, - hidden_size_per_attention_head, num_splits) +\ - input_shape[1:] - - t = t.view(*intermediate_shape) - t = t.transpose(1, 2).contiguous() - t = t.view(*input_shape) - - return t - - -def fix_query_key_value_ordering(model, checkpoint_version): - """Fix up query/key/value matrix ordering if checkpoint - version is smaller than 2.0 - """ - if checkpoint_version < 2.0: - if isinstance(model, list): - assert len(model)==1 - model = model[0] - for name, param in model.named_parameters(): - if name.endswith(('.query_key_value.weight', '.query_key_value.bias')): - if checkpoint_version == 0: - fixed_param = _transpose_first_dim(param.data, 3, True, model) - elif checkpoint_version == 1.0: - fixed_param = _transpose_first_dim(param.data, 3, False, model) - else: - print_rank_0(f"Invalid checkpoint version {checkpoint_version}.") - sys.exit() - param.data.copy_(fixed_param) - if name.endswith(('.key_value.weight', '.key_value.bias')): - if checkpoint_version == 0: - fixed_param = _transpose_first_dim(param.data, 2, True, model) - elif checkpoint_version == 1.0: - fixed_param = _transpose_first_dim(param.data, 2, False, model) - else: - print_rank_0(f"Invalid checkpoint version {checkpoint_version}.") - sys.exit() - param.data.copy_(fixed_param) - print_rank_0(" succesfully fixed query-key-values ordering for" - " checkpoint version {}".format(checkpoint_version)) - - -def _load_base_checkpoint(load_dir, rank0=False): - """ Load the base state_dict from the given directory - - If rank0 is true, just loads rank 0 checkpoint, ignoring arguments. - """ - - # Read the tracker file and set the iteration. 
- tracker_filename = get_checkpoint_tracker_filename(load_dir) - - # If no tracker file, return nothing - if not os.path.isfile(tracker_filename): - if not rank0: - print_rank_0('WARNING: could not find the metadata file {} '.format( - tracker_filename)) - print_rank_0(' will not load any checkpoints and will start from ' - 'random') - return None, "", False - - # Otherwise, read the tracker file and either set the iteration or - # mark it as a release checkpoint. - iteration, release = read_metadata(tracker_filename) - - # Checkpoint. - if rank0: - checkpoint_name = find_checkpoint_rank_0(load_dir, iteration, release) - else: - checkpoint_name = get_checkpoint_name(load_dir, iteration, release) - if release: - print_rank_0(f' loading release checkpoint from {load_dir}') - else: - print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}') - - # Load the checkpoint. - try: - state_dict = torch.load(checkpoint_name, map_location='cpu') - except ModuleNotFoundError: - from megatron.fp16_deprecated import loss_scaler - # For backward compatibility. - if not rank0: - print_rank_0(' > deserializing using the old code structure ...') - sys.modules['fp16.loss_scaler'] = sys.modules[ - 'megatron.fp16_deprecated.loss_scaler'] - sys.modules['megatron.fp16.loss_scaler'] = sys.modules[ - 'megatron.fp16_deprecated.loss_scaler'] - state_dict = torch.load(checkpoint_name, map_location='cpu') - sys.modules.pop('fp16.loss_scaler', None) - sys.modules.pop('megatron.fp16.loss_scaler', None) - except BaseException as e: - print_rank_0('could not load the checkpoint') - print_rank_0(e) - sys.exit() - - return state_dict, checkpoint_name, release - - -def load_args_from_checkpoint(args, load_arg='load'): - """Set required arguments from the checkpoint specified in the - arguments. - - Will overwrite arguments that have a non-None default value, but - will leave any arguments that default to None as set. - - Returns the same args NameSpace with the new values added/updated. - - If no checkpoint is specified in args, or if the checkpoint is - there but invalid, the arguments will not be modified - - """ - load_dir = getattr(args, load_arg) - - if load_dir is None: - print_rank_0('No load directory specified, using provided arguments.') - return args - - state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, rank0=True) - - # Args. 
- if not state_dict: - print_rank_0('Checkpoint not found to provide arguments, using provided arguments.') - return args - - if 'args' not in state_dict: - print_rank_0('Checkpoint provided does not have arguments saved, using provided arguments.') - return args - - checkpoint_args = state_dict['args'] - checkpoint_version = state_dict.get('checkpoint_version', 0) - args.iteration = state_dict['iteration'] - - # One-off conversion for foundation models - if hasattr(checkpoint_args, 'disable_bias_linear'): - setattr(checkpoint_args, 'add_bias_linear', not getattr(checkpoint_args, 'disable_bias_linear')) - - def _set_arg(arg_name, old_arg_name=None, force=False): - if not force and getattr(args, arg_name, None) is not None: - return - - if old_arg_name is not None: - checkpoint_value = getattr(checkpoint_args, old_arg_name, None) - else: - checkpoint_value = getattr(checkpoint_args, arg_name, None) - - if checkpoint_value is not None: - print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint") - setattr(args, arg_name, checkpoint_value) - else: - print_rank_0(f"Checkpoint did not provide arguments {arg_name}") - - _set_arg('num_layers') - _set_arg('hidden_size') - _set_arg('ffn_hidden_size') - _set_arg('seq_length') - _set_arg('num_attention_heads') - _set_arg('num_query_groups', force=True) - _set_arg('group_query_attention', force=True) - _set_arg('kv_channels') - _set_arg('max_position_embeddings') - _set_arg('position_embedding_type', force=True) - _set_arg('add_position_embedding', force=True) - _set_arg('use_rotary_position_embeddings', force=True) - _set_arg('rotary_percent', force=True) - _set_arg('add_bias_linear', force=True) - _set_arg('swiglu', force=True) - _set_arg('untie_embeddings_and_output_weights', force=True) - _set_arg('apply_layernorm_1p', force=True) - _set_arg('normalization', force=True) - _set_arg('tokenizer_type') - _set_arg('padded_vocab_size') - if checkpoint_version < 3.0: - _set_arg('tensor_model_parallel_size', - 'model_parallel_size') - else: - _set_arg('tensor_model_parallel_size', force=True) - _set_arg('pipeline_model_parallel_size', force=True) - _set_arg('virtual_pipeline_model_parallel_size', force=True) - _set_arg('num_layers_per_virtual_pipeline_stage') - return args, checkpoint_args - - -def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True): - """Load a model checkpoint and return the iteration. - strict (bool): whether to strictly enforce that the keys in - :attr:`state_dict` of the checkpoint match the names of - parameters and buffers in model. - """ - args = get_args() - load_dir = getattr(args, load_arg) - - model = unwrap_model(model) - - state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, rank0=False) - - # Checkpoint not loaded. - if state_dict is None: - - # Conditionally exit at this point. - if args.exit_on_missing_checkpoint: - print_rank_0(">> '--exit-on-missing-checkpoint' set ... exiting. <<") - torch.distributed.barrier() - sys.exit() - - # Iteration defaults to 0. - return 0 - - # Set checkpoint version. - set_checkpoint_version(state_dict.get('checkpoint_version', 0)) - - # Set iteration. 
- if args.finetune or release: - iteration = 0 - else: - try: - iteration = state_dict['iteration'] - except KeyError: - try: # Backward compatible with older checkpoints - iteration = state_dict['total_iters'] - except KeyError: - print_rank_0('A metadata file exists but unable to load ' - 'iteration from checkpoint {}, exiting'.format( - checkpoint_name)) - sys.exit() - - # Check arguments. - assert args.consumed_train_samples == 0 - assert args.consumed_valid_samples == 0 - if 'args' in state_dict and not args.finetune: - checkpoint_args = state_dict['args'] - check_checkpoint_args(checkpoint_args) - args.consumed_train_samples = getattr(checkpoint_args, - 'consumed_train_samples', 0) - update_num_microbatches(consumed_samples=args.consumed_train_samples) - args.consumed_valid_samples = getattr(checkpoint_args, - 'consumed_valid_samples', 0) - else: - print_rank_0('could not find arguments in the checkpoint ...') - - # Model. - if len(model) == 1: - model[0].load_state_dict(state_dict['model'], strict=strict) - else: - for i in range(len(model)): - mpu.set_virtual_pipeline_model_parallel_rank(i) - model[i].load_state_dict(state_dict['model%d' % i], strict=strict) - - # Fix up query/key/value matrix ordering if needed. - checkpoint_version = get_checkpoint_version() - print_rank_0(f' checkpoint version {checkpoint_version}') - fix_query_key_value_ordering(model, checkpoint_version) - - # Optimizer. - if not release and not args.finetune and not args.no_load_optim: - try: - # Load state dict. - if optimizer is not None: - optimizer.load_state_dict(state_dict['optimizer']) - - # Load distributed optimizer's custom parameter state. - if args.use_distributed_optimizer: - tracker_filename = get_checkpoint_tracker_filename(load_dir) - iteration, release = read_metadata(tracker_filename) - model_checkpoint_name = \ - get_checkpoint_name(load_dir, iteration, release) - optim_checkpoint_name = \ - get_distributed_optimizer_checkpoint_name( - model_checkpoint_name) - optimizer.load_parameter_state(optim_checkpoint_name) - - # Load scheduler. - if opt_param_scheduler is not None: - if 'lr_scheduler' in state_dict: # backward compatbility - opt_param_scheduler.load_state_dict(state_dict['lr_scheduler']) - else: - opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler']) - except KeyError: - print_rank_0('Unable to load optimizer from checkpoint {}. ' - 'Specify --no-load-optim or --finetune to prevent ' - 'attempting to load the optimizer state, ' - 'exiting ...'.format(checkpoint_name)) - sys.exit() - else: - if (args.fp16 or args.bf16) and optimizer is not None: - optimizer.reload_model_params() - - # rng states. 
- if not release and not args.finetune and not args.no_load_rng: - try: - if 'rng_state' in state_dict: - # access rng_state for data parallel rank - if args.data_parallel_random_init: - rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()] - else: - rng_state = state_dict['rng_state'][0] - random.setstate(rng_state['random_rng_state']) - np.random.set_state(rng_state['np_rng_state']) - torch.set_rng_state(rng_state['torch_rng_state']) - torch.cuda.set_rng_state(rng_state['cuda_rng_state']) - # Check for empty states array - if not rng_state['rng_tracker_states']: - raise KeyError - tensor_parallel.get_cuda_rng_tracker().set_states( - rng_state['rng_tracker_states']) - else: # backward compatability - random.setstate(state_dict['random_rng_state']) - np.random.set_state(state_dict['np_rng_state']) - torch.set_rng_state(state_dict['torch_rng_state']) - torch.cuda.set_rng_state(state_dict['cuda_rng_state']) - # Check for empty states array - if not state_dict['rng_tracker_states']: - raise KeyError - tensor_parallel.get_cuda_rng_tracker().set_states( - state_dict['rng_tracker_states']) - except KeyError: - print_rank_0('Unable to load rng state from checkpoint {}. ' - 'Specify --no-load-rng or --finetune to prevent ' - 'attempting to load the rng state, ' - 'exiting ...'.format(checkpoint_name)) - sys.exit() - - # Some utilities want to load a checkpoint without distributed being initialized - if torch.distributed.is_initialized(): - torch.distributed.barrier() - - print_rank_0(f' successfully loaded checkpoint from {args.load} ' - f'at iteration {iteration}') - - return iteration - - -def load_biencoder_checkpoint(model, only_query_model=False, - only_context_model=False, custom_load_path=None): - """ - selectively load retrieval models for indexing/retrieving - from saved checkpoints - """ - - args = get_args() - - model = unwrap_model(model) - - load_path = custom_load_path if custom_load_path is not None else args.load - - tracker_filename = get_checkpoint_tracker_filename(load_path) - with open(tracker_filename, 'r') as f: - iteration = int(f.read().strip()) - - checkpoint_name = get_checkpoint_name(load_path, iteration, - args.use_distributed_optimizer, - release=False) - - if mpu.get_data_parallel_rank() == 0: - print('global rank {} is loading checkpoint {}'.format( - torch.distributed.get_rank(), checkpoint_name)) - - state_dict = torch.load(checkpoint_name, map_location='cpu') - ret_state_dict = state_dict['model'] - - if only_query_model: - ret_state_dict.pop('context_model') - if only_context_model: - ret_state_dict.pop('query_model') - - assert len(model) == 1 - model[0].load_state_dict(ret_state_dict) - torch.distributed.barrier() - - if mpu.get_data_parallel_rank() == 0: - print(' successfully loaded {}'.format(checkpoint_name)) - - return model diff --git a/megatron/core/QuickStart.md b/megatron/core/QuickStart.md new file mode 100644 index 0000000000..6deb1a5f76 --- /dev/null +++ b/megatron/core/QuickStart.md @@ -0,0 +1,250 @@ +## Quick Start + +The following guide is a short getting started guide for Megatron Core. In it you: + +* Initialize Megatron Core on 2 GPUS. +* Build a GPT model with tensor model parallel size 2, pipeline parallel size 1 +* Train it for a five iterations using Megatron Core schedules +* Save the model using the distributed checkpointing format +* Load the model saved above. + +**NOTE:** The following sample was tested using Megatron Core version 0.8.0 and NGC PyTorch Container version 24.02. 
+ +### Environment Setup + +``` +docker run --ipc=host --shm-size=512m --gpus 2 -it nvcr.io/nvidia/pytorch:24.02-py3 + +git clone https://github.com/NVIDIA/Megatron-LM.git && cd Megatron-LM +``` +
+
+### Writing Your First Training Loop
+
+In the following steps, you create a sample GPT model that is split across tensors (tensor model parallelism) on 2 GPUs, and run a forward pass through it using the MockGPT dataset helper class provided in Megatron Core.
+
+ +**NOTE:** All of the following steps are in the [run_simple_mcore_train_loop.py](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/run_simple_mcore_train_loop.py) script. + +To run the ``run_simple_mcore_train_loop.py`` script: + +``` +PYTHONPATH=$PYTHON_PATH:./megatron torchrun --nproc-per-node 2 examples/run_simple_mcore_train_loop.py +``` + +
+ +**STEP 1 - Initialize Distributed Training and Model Parallel Setup** + +The following utility, when called, initializes your distributed setup. + +```python +import os +import torch +from megatron.core import parallel_state + +def initialize_distributed(tensor_model_parallel_size = 1, pipeline_model_parallel_size = 1): + # Torch setup for distributed training + rank = int(os.environ['LOCAL_RANK']) + world_size = torch.cuda.device_count() + torch.cuda.set_device(rank) + torch.distributed.init_process_group(world_size=world_size, rank=rank) + + # Megatron core distributed training initialization + parallel_state.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size) +``` +
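+
+The quick-start script does not tear these process groups down explicitly, but for longer-lived programs you may want to inspect or clean up what `initialize_distributed` created. The following is a minimal, illustrative sketch (the helper names are not part of Megatron Core); it only uses standard `parallel_state` queries and `torch.distributed` teardown calls.
+
+```python
+import torch
+from megatron.core import parallel_state
+
+def describe_parallel_state():
+    # Query the topology Megatron Core derived for this process.
+    tp_rank = parallel_state.get_tensor_model_parallel_rank()
+    tp_size = parallel_state.get_tensor_model_parallel_world_size()
+    print(f'tensor model parallel rank {tp_rank} of {tp_size}')
+
+def teardown_distributed():
+    # Release the model parallel groups, then the global process group.
+    parallel_state.destroy_model_parallel()
+    torch.distributed.destroy_process_group()
+```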
+
+**STEP 2 - GPT Model Setup**
+
+In this step, you create a GPT model. For a list of other configurations that you can pass into the model, open and review [transformer_config.py](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/transformer/transformer_config.py).
+
+```python
+from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.core.models.gpt.gpt_model import GPTModel
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
+
+def model_provider():
+    """Build the model."""
+
+    transformer_config = TransformerConfig(
+        num_layers=2,
+        hidden_size=12,
+        num_attention_heads=4,
+        use_cpu_initialization=True,
+        pipeline_dtype=torch.float32)
+
+    gpt_model = GPTModel(
+        config=transformer_config,
+        transformer_layer_spec=get_gpt_layer_local_spec(),
+        vocab_size=100,
+        max_sequence_length=64)
+
+    return gpt_model
+```
+
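+
+If you want to see the effect of tensor model parallelism on the model you just built, the following optional sketch (not part of the quick-start script) counts the parameter elements stored on the local rank; with `tensor_model_parallel_size=2`, each rank holds only its shard of the sharded layers.
+
+```python
+import torch
+
+def count_local_parameters(model: torch.nn.Module) -> int:
+    """Return the number of parameter elements stored on this rank."""
+    return sum(p.numel() for p in model.parameters())
+
+# Example usage after building the model:
+#   gpt_model = model_provider()
+#   print(f'parameters on this rank: {count_local_parameters(gpt_model)}')
+```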
+ +**STEP 3 - GPT Mock Dataset Setup** + +In the following step, you explore the mock dataset utility. + +* To train the model using your data, use the GPTDataset class in [gpt_dataset.py](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/datasets/gpt_dataset.py). + +* To find more information about Megatron Core data pipeline, see the [data pipeline readme.md](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/datasets/readme.md?ref_type=heads). + +``` +import torch +from torch.utils.data import DataLoader + +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.core.datasets.gpt_dataset import GPTDatasetConfig, MockGPTDataset +from megatron.training.tokenizer.tokenizer import _NullTokenizer +from megatron.core.datasets.utils import compile_helpers + +_SEQUENCE_LENGTH = 64 + +def get_train_data_iterator(): + if torch.distributed.is_available() and torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + compile_helpers() + torch.distributed.barrier() + else: + compile_helpers() + + config = GPTDatasetConfig( + random_seed=0, + sequence_length=_SEQUENCE_LENGTH, + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False, + tokenizer=_NullTokenizer(vocab_size=_SEQUENCE_LENGTH), + ) + + datasets = BlendedMegatronDatasetBuilder( + MockGPTDataset, [1000, None, None], lambda: True, config + ).build() + + train_dataloader = DataLoader(datasets[0], batch_size=8, shuffle=True) + + train_iterator = iter(train_dataloader) + + return train_iterator + +``` +
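+
+To see what the mock dataset produces, you can optionally pull a single batch and print its fields; these are the same keys consumed by the forward step in the next section. This snippet is illustrative and assumes it runs in the same script, after the helper above is defined.
+
+```python
+train_iterator = get_train_data_iterator()
+batch = next(train_iterator)
+
+# Each entry is a batched torch.Tensor produced by the DataLoader.
+for key in ('tokens', 'position_ids', 'attention_mask', 'labels', 'loss_mask'):
+    print(key, tuple(batch[key].shape), batch[key].dtype)
+```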
+ +**STEP 4 - Forward Step Function** + +Megatron Core uses [schedules.py](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/pipeline_parallel/schedules.py) to run the model. It is sufficient to define a forward step function, which takes as input the data iterator and the model and produces as output the output tensor and a loss function. + +```python +from functools import partial + +def forward_step_func(data_iterator, model): + + def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor): + + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + # If you have data parallel reduce loss across data parallel groups. + # If pipeline parallel, loss computation is done only in last stage. + + return loss, {'lm loss': loss} + + data = next(data_iterator) + tokens = data['tokens'].to(device) + attention_mask = data['attention_mask'].to(device) + position_ids = data['position_ids'].to(device) + labels = data['labels'].to(device) + loss_mask = data['loss_mask'].to(device) + + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels) + + return output_tensor, partial(loss_func, loss_mask) +``` +
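+
+The comment about data parallelism in the loss function above is left as an exercise in the quick start. If you do train with data parallel replicas, one common pattern is to average the reported loss across the data parallel group before logging it. The helper below is a minimal sketch of that reduction, not part of the quick-start script; it only uses `torch.distributed.all_reduce` and the standard Megatron Core `parallel_state` group accessors.
+
+```python
+import torch
+from megatron.core import parallel_state
+
+def average_loss_across_data_parallel_group(loss: torch.Tensor) -> torch.Tensor:
+    """Average a scalar loss tensor across the data parallel replicas."""
+    averaged = loss.clone().detach()
+    torch.distributed.all_reduce(
+        averaged,
+        op=torch.distributed.ReduceOp.SUM,
+        group=parallel_state.get_data_parallel_group(),
+    )
+    averaged /= parallel_state.get_data_parallel_world_size()
+    return averaged
+```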
+ +**STEP 5 - Load and Save Distributed Checkpoint** + +Megatron Core uses distributed checkpoints for loading and saving models. This gives you the flexibility to convert the model from one model parallel setting to another when you load a model. For example, a model trained with tensor parallel size 2, can be loaded again as tensor model parallel size 4, and so forth. + +```python +from megatron.core import dist_checkpointing + +def save_distributed_checkpoint(checkpoint_path, gpt_model): + sharded_state_dict = gpt_model.sharded_state_dict(prefix='') + dist_checkpointing.save(sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path) + +def load_distributed_checkpoint(checkpoint_path, gpt_model): + sharded_state_dict=gpt_model.sharded_state_dict(prefix='') + checkpoint = dist_checkpointing.load(sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path) + gpt_model.load_state_dict(checkpoint) + return gpt_model +``` +
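+
+As an optional illustration of the resharding described above, the sketch below shows what a *separate* run could look like that reloads the checkpoint saved by this guide with a different tensor model parallel size. It reuses `initialize_distributed`, `model_provider`, and `load_distributed_checkpoint` from this guide and assumes the new job was launched on 4 GPUs (for example with `torchrun --nproc-per-node 4`).
+
+```python
+import torch
+from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
+
+# Re-initialize with a larger tensor model parallel size than the one used for training.
+initialize_distributed(tensor_model_parallel_size=4, pipeline_model_parallel_size=1)
+model_parallel_cuda_manual_seed(123)
+
+gpt_model = model_provider()
+gpt_model.to(torch.device("cuda"))
+
+# The distributed checkpoint format handles the resharding from TP=2 to TP=4.
+gpt_model = load_distributed_checkpoint(checkpoint_path='/workspace/ckpt', gpt_model=gpt_model)
+```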
+
+**STEP 6 - Main Function**
+
+The following code snippet is the main function that needs to go into your script. It runs the model for five iterations, saves the model, and then loads it back.
+
+```python
+from pathlib import Path
+from torch.optim import Adam
+from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
+from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
+
+if __name__ == "__main__":
+    initialize_distributed(tensor_model_parallel_size=2, pipeline_model_parallel_size=1)
+    model_parallel_cuda_manual_seed(123)
+
+    gpt_model = model_provider()
+    device = torch.device("cuda")
+    gpt_model.to(device)
+
+    optim = Adam(gpt_model.parameters())
+
+    train_iterator = get_train_data_iterator()
+
+    forward_backward_func = get_forward_backward_func()
+
+    # Running the model for 5 iterations
+    for _ in range(5):
+        optim.zero_grad()
+
+        losses_reduced = forward_backward_func(
+            forward_step_func=forward_step_func,
+            data_iterator=train_iterator,
+            model=gpt_model,
+            num_microbatches=1,
+            seq_length=64,
+            micro_batch_size=8,
+            decoder_seq_length=64,
+            forward_only=False)
+
+        optim.step()
+
+        print(f'Losses reduced : {losses_reduced}')
+
+    # Saving the model
+    save_distributed_checkpoint(gpt_model=gpt_model, checkpoint_path='/workspace/ckpt')
+
+    # Loading the model
+    gpt_model = load_distributed_checkpoint(gpt_model=gpt_model, checkpoint_path='/workspace/ckpt')
+    gpt_model.to(device)
+    print('Successfully loaded the model')
+```
+
+
+
+
+### Extending Further
+
+The example you explored here is a basic training loop in Megatron Core. To review more advanced examples, explore [pretrain_gpt.py]. ``pretrain_gpt.py`` has more complex training loops that include the following and other Megatron Core features:
+
+* pipeline parallel
+* context parallel
+* rope embeddings
+* mixture of experts
diff --git a/megatron/core/README.md b/megatron/core/README.md
index 0c8c61738d..38970b0c47 100644
--- a/megatron/core/README.md
+++ b/megatron/core/README.md
@@ -1 +1,14 @@
-Megatron Core is a library for efficient and scalable training of transformer based models.
+# Megatron-Core
+
+Megatron-Core is an open-source PyTorch-based library that contains GPU-optimized techniques and cutting-edge system-level optimizations. It abstracts them into composable and modular APIs, allowing full flexibility for developers and model researchers to train custom transformers at-scale on NVIDIA accelerated computing infrastructure. This library is compatible with all NVIDIA Tensor Core GPUs, including FP8 acceleration support for [NVIDIA Hopper architectures](https://www.nvidia.com/en-us/data-center/technologies/hopper-architecture/).
+
+Megatron-Core offers core building blocks such as attention mechanisms, transformer blocks and layers, normalization layers, and embedding techniques. Additional functionality, such as activation re-computation and distributed checkpointing, is also natively built into the library. The building blocks and functionality are all GPU optimized, and can be built with advanced parallelization strategies for optimal training speed and stability on NVIDIA Accelerated Computing Infrastructure. Another key component of the Megatron-Core library is its advanced model parallelism techniques (tensor, sequence, pipeline, context, and MoE expert parallelism).
+
+Megatron-Core can be used with [NVIDIA NeMo](https://www.nvidia.com/en-us/ai-data-science/products/nemo/), an enterprise-grade AI platform. Alternatively, you can explore Megatron-Core with the native PyTorch training loop [here](https://github.com/NVIDIA/Megatron-LM/tree/main/examples). Visit [Megatron-Core documentation](https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html) to learn more.
+
+## Quick links
+
+- [Benchmark using NVIDIA NeMo](https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html#performance-benchmarks)
+- [Multimodal example (LLaVA training pipeline)](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/multimodal)
+- [Mixture-of-Experts](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/transformer/moe)
+- [Training Mamba-based Language Models](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/mamba)
diff --git a/megatron/core/README_STRAGGLER.md b/megatron/core/README_STRAGGLER.md
new file mode 100644
index 0000000000..fe9062c851
--- /dev/null
+++ b/megatron/core/README_STRAGGLER.md
@@ -0,0 +1,93 @@
+## StragglerDetector for a TP Group
+
+The file `megatron/core/utils.py` has a class named `StragglerDetector` which supports Python contexts.
+It can be used to find a straggling TP group based on the RTT of the ranks in the TP group. It also collects
+Power/Temp/Utilization for GPUs, which can additionally be used to narrow down to the exact GPU in the TP group,
+assuming the straggling was caused by a hardware anomaly in a given GPU.
+This class supports collecting timing events for the various steps of a given iteration. It
+keeps collecting such timing events on a per-rank basis, and when the reporter is invoked
+during a logging interval, it computes the min and max of certain metrics across all
+ranks and logs the observed metric and the corresponding rank as follows:
+
+```
+ 0: INFO:megatron.core.utils:[2024-03-14 23:07:56] | MnRtt/Rnk: 3453.08ms/8 | MxRtt/Rnk: 3468.20ms/0 | MnPwr/Rnk: 601796W/8 | MxPwr/Rnk: 683801W/18 | MnTmp/Rnk: 52C/0 | MxTmp/Rnk: 65C/21 | MnUtl/Rnk: 97%/8 | MxUtl/Rnk: 100%/6 | MnClk/Rnk: 1950MHz/28 | MxClk/Rnk: 1980MHz/0 | MnDRtt/Rnk: 14.27ms/23 | MxDRtt/Rnk: 34.65ms/3 | MnEtpt/Rnk: 296.02TF/0 | MxEtpt/Rnk: 297.32TF/8
+```
+
+
+### Description of the metrics
+
+Each metric is prefixed with `Mn` or `Mx` to represent `Minimum` or `Maximum`. Each metric is also suffixed with the rank where the metric was measured. The metrics are averaged over the logging interval. Between the prefix and the rank is the name of the metric, as follows:
+
+- Rtt : RoundTrip Time (time spent in all the traced ops per iteration)
+- Pwr : GPU Power
+- Tmp : GPU Temperature
+- Utl : GPU Utilization
+- Clk : GPU Clock
+- DRtt: get_batch latency
+- Etpt: Estimated throughput. This is derived from the actual computed throughput divided by Rtt. Since we do not collect timing for the backward pass, the value is further divided by three to arrive at the estimated throughput (see the sketch below).
+
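+
+A rough sketch of the Etpt arithmetic described above is shown below, with made-up numbers; the real computation happens inside `StragglerDetector.report()`. The accumulated FLOPs are divided by the measured Rtt, and the result is divided by three because only the forward-pass ops are timed.
+
+```
+def estimated_throughput_tflops(total_flops_per_iteration, rtt_seconds):
+    # Only forward-pass ops are timed, so divide by 3 to approximate fwd+bwd cost.
+    flops_per_second = total_flops_per_iteration / rtt_seconds
+    return flops_per_second / 3 / 1e12
+
+# Illustrative numbers only, roughly matching the sample log above:
+# estimated_throughput_tflops(3.1e15, 3.47) ~= 298 TF
+```
+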
+
+### Command Line activation
+To start using the StragglerDetector, you need to pass the `--log-straggler` argument; straggler detection is disabled by default. It optionally takes the following additional parameters:
+- `--disable-straggler-on-startup` - whether to keep the StragglerDetector disabled on startup and enable it later. Default enabled
+- `--straggler-ctrlr-port` - The StragglerDetector can be toggled on/off by sending `curl Rank0Host:port`. Default port is 65535. Every time the command is sent, the detector toggles between the on and off states (see the sketch below).
+- `--straggler-minmax-count` - If set to > 1 (N), it prints the N top and bottom Etpt/Rank pairs, as shown below
+```
+ 0: INFO:megatron.core.utils:^^^^ Bottom 4 Ranks with lowest  Etpt(TF): 296.02/0, 296.17/2, 296.23/1, 296.23/4,
+ 0: INFO:megatron.core.utils:^^^^ Top    4 Ranks with highest Etpt(TF): 297.28/15, 297.28/11, 297.32/12, 297.32/8,
+```
+
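+
+The `curl Rank0Host:port` command above is just a plain GET against the control port on rank 0. The snippet below is a hypothetical Python equivalent using only the standard library; it assumes the control port accepts an HTTP GET (as the curl usage implies) and is wrapped in a try/except so an unexpected reply format does not crash the calling script.
+
+```
+import urllib.request
+
+def toggle_straggler_detector(rank0_host, port=65535):
+    # Mirrors `curl Rank0Host:port`; each request toggles the detector on/off.
+    try:
+        with urllib.request.urlopen(f"http://{rank0_host}:{port}", timeout=5) as response:
+            print(response.read().decode(errors="replace"))
+    except Exception as exc:
+        print(f"request attempted; could not parse a response ({exc})")
+```
+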
+ +### Programming the StragglerDetector +The StragglerDetector class supports context, and its implementation is a Singleton. +- Initialization + +``` + # initialization, where StragglerDetector will be used + from megatron.core.utils import StragglerDetector + stimer = StragglerDetector() +``` + +- One time for each rank + +``` + # one time before the training loop starts + stimer.configure(world, rank, enabled=True, port=65545) + + # Arguments to configure + # world : World Size + # rank : The rank of this trainer + # mmcnt : (Optional) Number of ranks to print for showing Min/Max Etpt + # amp : (Optional) Set to 3.0 if we only use timers in fwd pass + # port : (Optional) control port, useful only for rank-0 + # prefill : (Optional) howmany Events to pre-populate + # enabled : (Optional) whether or not collection is enabled on startup +``` + +- To Capture time + +``` + # whereever timing need to be captured + with stimer: + do_operation() + + # special case for get_batch + with stimer(bdata=True): + input,... = get_batch(iterator,...) +``` + +- Logging in main training loop + +``` + # logging + total_flops = 0.0 + iteration = 0 + # inside the main training loop + while training: + iteration += 1 + do_step() + total_flops += get_computed_flops() + if iteration % log_interval: + stimer.report(total_flops, log_interval) + total_flops = 0.0 +``` diff --git a/megatron/core/__init__.py b/megatron/core/__init__.py index 7457708229..0eccb1d02e 100644 --- a/megatron/core/__init__.py +++ b/megatron/core/__init__.py @@ -1,9 +1,24 @@ -import megatron.core.parallel_state +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. import megatron.core.tensor_parallel import megatron.core.utils +from megatron.core import parallel_state from megatron.core.distributed import DistributedDataParallel from megatron.core.inference_params import InferenceParams from megatron.core.model_parallel_config import ModelParallelConfig +from megatron.core.package_info import ( + __contact_emails__, + __contact_names__, + __description__, + __download_url__, + __homepage__, + __keywords__, + __license__, + __package_name__, + __repository_url__, + __shortversion__, + __version__, +) +from megatron.core.timers import Timers # Alias parallel_state as mpu, its legacy name mpu = parallel_state @@ -15,4 +30,5 @@ "DistributedDataParallel", "InferenceParams", "ModelParallelConfig", + "Timers", ] diff --git a/megatron/core/config_logger.py b/megatron/core/config_logger.py new file mode 100644 index 0000000000..231a0226be --- /dev/null +++ b/megatron/core/config_logger.py @@ -0,0 +1,104 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +import dataclasses +import json +import os + +import torch +import torch.nn as nn + +from megatron.core import parallel_state + + +def get_config_logger_path(config): + return getattr(config, 'config_logger_dir', '') + + +def has_config_logger_enabled(config): + return get_config_logger_path(config) != '' + + +# For each prefix, holds a counter and increases it every time we dump with this +# prefix. 
+__config_logger_path_counts = {} + + +def get_path_count(path): + """ + keeps tracks of number of times we've seen the input `path` and return count-1 + """ + global __config_logger_path_counts + if not path in __config_logger_path_counts: + __config_logger_path_counts[path] = 0 + count = __config_logger_path_counts[path] + __config_logger_path_counts[path] += 1 + return count + + +def get_path_with_count(path): + """ + calls get_path_count and appends returned value to path + """ + return f'{path}.iter{get_path_count(path)}' + + +class JSONEncoderWithMcoreTypes(json.JSONEncoder): + def default(self, o): + if type(o).__name__ in ['function', 'ProcessGroup']: + return str(o) + if type(o).__name__ in ['dict', 'OrderedDict']: + return {k: self.default(v) for k, v in o.items()} + if type(o).__name__ in ['list', 'ModuleList']: + return [self.default(val) for val in o] + if type(o).__name__ == 'UniqueDescriptor': + return { + attr: self.default(getattr(o, attr)) + for attr in filter(lambda x: not x.startswith('__'), dir(o)) + } + if type(o) is torch.dtype: + return str(o) + # if it's a Float16Module, add "Float16Module" to the output dict + if type(o).__name__ == 'Float16Module': + return {'Float16Module': {'module': self.default(o.module)}} + # If it's a nn.Module subchild, either print its children or itself if leaf. + if issubclass(type(o), nn.Module): + if len(getattr(o, '_modules', {})) > 0: + return {key: self.default(val) for key, val in o._modules.items()} + else: + return str(o) + if type(o).__name__ in ['ABCMeta', 'type', 'AttnMaskType']: + return str(o) + if dataclasses.is_dataclass(o) or type(o).__name__ in ['ModuleSpec', 'TransformerConfig']: + return dataclasses.asdict(o) + try: + return super().default(o) + except: + return str(o) + + +def log_config_to_disk(config, dict_data, prefix=''): + """ + Encodes the input dict (dict_data) using the JSONEncoderWithMcoreTypes + and dumps to disk, as specified via path + """ + path = get_config_logger_path(config) + assert path is not None, 'Expected config_logger_dir to be non-empty in config.' + + if 'self' in dict_data: + if prefix == '': + prefix = type(dict_data['self']).__name__ + del dict_data['self'] + + if not os.path.exists(path): + os.makedirs(path, exist_ok=True) + + rank = parallel_state.get_all_ranks() + path = get_path_with_count(os.path.join(path, f'{prefix}.rank_{rank}')) + if type(dict_data).__name__ == 'OrderedDict': + torch.save(dict_data, f'{path}.pth') + else: + with open(f'{path}.json', 'w') as fp: + json.dump(dict_data, fp, cls=JSONEncoderWithMcoreTypes) + + +__all__ = ['has_config_logger_enabled', 'log_config_to_disk'] diff --git a/megatron/core/datasets/Makefile b/megatron/core/datasets/Makefile index 8f9db76866..e745f52399 100644 --- a/megatron/core/datasets/Makefile +++ b/megatron/core/datasets/Makefile @@ -1,9 +1,13 @@ CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color CPPFLAGS += $(shell python3 -m pybind11 --includes) -LIBNAME = helpers + +LIBNAME = helpers_cpp LIBEXT = $(shell python3-config --extension-suffix) -default: $(LIBNAME)$(LIBEXT) +OUT = $(LIBNAME)$(LIBEXT) +SRC = helpers.cpp + +default: $(OUT) -%$(LIBEXT): %.cpp +$(OUT): $(SRC) $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ diff --git a/megatron/core/datasets/bert_dataset.py b/megatron/core/datasets/bert_dataset.py new file mode 100644 index 0000000000..78ae2edf62 --- /dev/null +++ b/megatron/core/datasets/bert_dataset.py @@ -0,0 +1,192 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+ +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + +import numpy + +from megatron.core.datasets.indexed_dataset import IndexedDataset +from megatron.core.datasets.masked_dataset import ( + MaskedWordPieceDataset, + MaskedWordPieceDatasetConfig, +) +from megatron.core.datasets.utils import Split + + +@dataclass +class BERTMaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig): + """Configuration object for Megatron Core BERT WordPiece datasets""" + + classification_head: bool = None + """Option to perform the next sequence prediction during sampling""" + + def __post_init__(self) -> None: + """Do asserts and set fields post init""" + super().__post_init__() + + assert self.classification_head is not None + + +class BERTMaskedWordPieceDataset(MaskedWordPieceDataset): + """The BERT dataset that assumes WordPiece tokenization + + Args: + indexed_dataset (IndexedDataset): The IndexedDataset around which to build the MegatronDataset + + dataset_path (str): The real path on disk to the dataset, for bookkeeping + + indexed_indices (numpy.ndarray): The set of the documents indices to expose + + num_samples (Optional[int]): The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch. + + index_split (Split): The indexed_indices Split + + config (BERTMaskedWordPieceDatasetConfig): The config + """ + + def __init__( + self, + indexed_dataset: IndexedDataset, + dataset_path: str, + indexed_indices: numpy.ndarray, + num_samples: Optional[int], + index_split: Split, + config: BERTMaskedWordPieceDatasetConfig, + ) -> None: + super().__init__( + indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config + ) + + self.token_lookup = list(self.config.tokenizer.inv_vocab.keys()) + # Account for the single and two token ids + self.sample_index = self._build_sample_index( + self.config.sequence_length - 3, 2 if self.config.classification_head else 1 + ) + + @staticmethod + def _key_config_attributes() -> List[str]: + """Inherited method implementation + + Returns: + List[str]: The key config attributes + """ + return super( + BERTMaskedWordPieceDataset, BERTMaskedWordPieceDataset + )._key_config_attributes() + ["classification_head"] + + def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]: + """Abstract method implementation + + Args: + idx (int): The index into the dataset + + Returns: + Dict[str, Union[int, numpy.ndarray]]: The + """ + idx_beg, idx_end, target_sequence_length = self.sample_index[idx] + sample = [self.dataset[i] for i in range(idx_beg, idx_end)] + numpy_random_state = numpy.random.RandomState(seed=(self.config.random_seed + idx) % 2**32) + + assert target_sequence_length <= self.config.sequence_length + + # Split the sample into contiguous subsegments A and B + pivot = len(sample) + is_next_random = False + if self.config.classification_head: + assert len(sample) > 1, "the sample must contain at least two sentences" + pivot = 1 + if len(sample) >= 3: + pivot = numpy_random_state.randint(low=1, high=len(sample)) + is_next_random = numpy_random_state.random() < 0.5 + split_A = [] + for sample_a in sample[:pivot]: + split_A.extend(sample_a) + split_B = [] + for sample_b in sample[pivot:]: + split_B.extend(sample_b) + if is_next_random: + split_A, split_B = split_B, split_A + + # Trim the subsegments from either end to a desired joint length + length_A = len(split_A) + length_B = len(split_B) + if length_A + length_B <= target_sequence_length: + truncated = False 
+ else: + while length_A + length_B > target_sequence_length: + split = split_A if length_A > length_B else split_B + if numpy_random_state.random() < 0.5: + del split[0] + else: + del split[-1] + length_A = len(split_A) + length_B = len(split_B) + truncated = True + + # Merge the subsegments and create the token assignment labels + tokens = [self.config.tokenizer.cls, *split_A, self.config.tokenizer.sep] + assignments = [0 for _ in range(1 + len(split_A) + 1)] + if split_B: + tokens += [*split_B, self.config.tokenizer.sep] + assignments += [1 for _ in range(len(split_B) + 1)] + + # Masking + tokens, masked_positions, masked_labels, _, _ = self._create_masked_lm_predictions( + tokens, target_sequence_length, numpy_random_state + ) + + # Pad the sequences and convert to NumPy + length_toks = len(tokens) + length_pads = self.config.sequence_length - length_toks + assert length_pads >= 0 + + tokens = numpy.array(tokens, dtype=numpy.int64) + tokens = numpy.pad(tokens, (0, length_pads), constant_values=self.config.tokenizer.pad) + + assignments = numpy.array(assignments, dtype=numpy.int64) + assignments = numpy.pad( + assignments, (0, length_pads), constant_values=self.config.tokenizer.pad + ) + + # Get the padding mask + mask_pads = numpy.ones(length_toks, dtype=numpy.int64) + mask_pads = numpy.pad( + mask_pads, (0, length_pads), constant_values=self.config.tokenizer.pad + ) + + # Mask the labels + labels = numpy.zeros(self.config.sequence_length, dtype=numpy.int64) - 1 + labels[masked_positions] = masked_labels + + # Get the loss mask + mask_loss = numpy.zeros(self.config.sequence_length, dtype=numpy.int64) + mask_loss[masked_positions] = 1 + + return { + "text": tokens, + "types": assignments, + "labels": labels, + "is_random": int(is_next_random), + "padding_mask": mask_pads, + "loss_mask": mask_loss, + "truncated": int(truncated), + } + + def _get_token_mask(self, numpy_random_state: numpy.random.RandomState) -> Optional[int]: + """Abstract method implementation + + 80% of the time, replace the token id with mask token id. 10% of the time, replace token id + with a random token id from the vocabulary. 10% of the time, do nothing. 
+ + Args: + numpy_random_state (RandomState): The NumPy random state + + Returns: + Optional[int]: The replacement token id or None + """ + if numpy_random_state.random() < 0.8: + return self.config.tokenizer.mask + else: + if numpy_random_state.random() >= 0.5: + return self.token_lookup[numpy_random_state.randint(0, len(self.token_lookup))] + return None diff --git a/megatron/core/datasets/blended_dataset.py b/megatron/core/datasets/blended_dataset.py index 89f3bbc9e5..4628686e5b 100644 --- a/megatron/core/datasets/blended_dataset.py +++ b/megatron/core/datasets/blended_dataset.py @@ -6,14 +6,15 @@ import os import time from collections import OrderedDict -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy import torch from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig from megatron.core.datasets.megatron_dataset import MegatronDataset -from megatron.core.datasets.utils import log_single_rank, normalize +from megatron.core.datasets.utils import normalize +from megatron.core.utils import log_single_rank logger = logging.getLogger(__name__) @@ -26,11 +27,12 @@ class BlendedDataset(torch.utils.data.Dataset): Args: datasets (List[MegatronDataset]): The MegatronDataset instances to blend - weights (List[float]): The weights which determines the dataset blend ratios + weights (List[Union[int, float]]): The weights that determine the dataset blend ratios - size (int): The number of samples to draw from the blend + size (Optional[int]): The number of samples to draw from the blend. If None, for each + dataset index idx draw exactly weights[idx] samples from datasets[idx]. - config (BlendedMegatronDatasetConfig): The config object which informs dataset creation + config (BlendedMegatronDatasetConfig): The config Raises: RuntimeError: When the dataset has fewer or more samples than 'size' post-initialization @@ -39,14 +41,18 @@ class BlendedDataset(torch.utils.data.Dataset): def __init__( self, datasets: List[MegatronDataset], - weights: List[float], - size: int, + weights: List[Union[int, float]], + size: Optional[int], config: BlendedMegatronDatasetConfig, ) -> None: - assert len(datasets) < 32767 assert len(datasets) == len(weights) - assert numpy.isclose(sum(weights), 1.0) + assert len(datasets) < 32767 assert all(map(lambda _: type(_) == type(datasets[0]), datasets)) + assert all(map(lambda _: _.index_split == datasets[0].index_split, datasets)) + assert all(map(lambda _: _ > 0, weights)) + assert all(map(lambda _: type(_) == type(weights[0]), weights)) + if size is None and isinstance(weights[0], float): + assert all(map(lambda _: _ == int(_), weights)) # Alert user to unnecessary blending if len(datasets) == 1: @@ -54,10 +60,11 @@ def __init__( logger, logging.WARNING, f"Building a BlendedDataset for a single MegatronDataset" ) - # Redundant normalization for bitwise identical comparison with Megatron-LM - weights = normalize(weights) + if size is not None: + weights = normalize(weights) self.datasets = datasets + self.split = self.datasets[0].index_split self.weights = weights self.size = size self.config = config @@ -65,34 +72,28 @@ def __init__( unique_identifiers = OrderedDict() unique_identifiers["class"] = type(self).__name__ unique_identifiers["datasets"] = [dataset.unique_identifiers for dataset in self.datasets] + unique_identifiers["split"] = self.split.name unique_identifiers["weights"] = self.weights unique_identifiers["size"] = self.size - self.unique_description = 
json.dumps(unique_identifiers, indent=4) + self.unique_description = json.dumps( + unique_identifiers, indent=4, default=lambda obj: obj.unique_identifiers + ) self.unique_description_hash = hashlib.md5( self.unique_description.encode("utf-8") ).hexdigest() - self.dataset_index, self.dataset_sample_index = self._build_indices() + self.built_anew_on_cache_miss = False - # Check size - _ = self[self.size - 1] - try: - _ = self[self.size] - raise RuntimeError(f"{type(self).__name__} size is improperly bounded") - except IndexError: - log_single_rank(logger, logging.INFO, f"> {type(self).__name__} length: {len(self)}") + self.dataset_index, self.dataset_sample_index = self._build_indices() def __len__(self) -> int: - return self.size + return self.dataset_index.shape[0] def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]: dataset_id = self.dataset_index[idx] dataset_sample_id = self.dataset_sample_index[idx] - return { - "dataset_id": dataset_id, - **self.datasets[dataset_id][dataset_sample_id], - } + return {"dataset_id": dataset_id, **self.datasets[dataset_id][dataset_sample_id]} def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]: """Build and optionally cache the dataset index and the dataset sample index @@ -104,11 +105,12 @@ def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]: Returns: Tuple[numpy.ndarray, numpy.ndarray]: The dataset index and the dataset sample index """ - path_to_cache = getattr(self.config, "path_to_cache") + path_to_cache = self.config.path_to_cache if path_to_cache: get_path_to = lambda suffix: os.path.join( - path_to_cache, f"{self.unique_description_hash}-{type(self).__name__}-{suffix}" + path_to_cache, + f"{self.unique_description_hash}-{type(self).__name__}-{self.split.name}-{suffix}", ) path_to_description = get_path_to("description.txt") path_to_dataset_index = get_path_to("dataset_index.npy") @@ -124,8 +126,9 @@ def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]: if not path_to_cache or (not cache_hit and torch.distributed.get_rank() == 0): log_single_rank( - logger, logging.INFO, f"Build and save the {type(self).__name__} indices", + logger, logging.INFO, f"Build and save the {type(self).__name__} indices" ) + self.built_anew_on_cache_miss = True # Build the dataset and dataset sample indexes log_single_rank( @@ -134,16 +137,24 @@ def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]: t_beg = time.time() from megatron.core.datasets import helpers - dataset_index = numpy.zeros(self.size, dtype=numpy.int16) - dataset_sample_index = numpy.zeros(self.size, dtype=numpy.int64) - helpers.build_blending_indices( - dataset_index, - dataset_sample_index, - self.weights, - len(self.datasets), - self.size, - _VERBOSE, - ) + if self.size is not None: + dataset_index = numpy.zeros(self.size, dtype=numpy.int16) + dataset_sample_index = numpy.zeros(self.size, dtype=numpy.int64) + helpers.build_blending_indices( + dataset_index, + dataset_sample_index, + self.weights, + len(self.datasets), + self.size, + _VERBOSE, + ) + else: + size = sum(self.weights) + dataset_index = numpy.zeros(size, dtype=numpy.int16) + dataset_sample_index = numpy.zeros(size, dtype=numpy.int64) + helpers.build_exhaustive_blending_indices( + dataset_index, dataset_sample_index, self.weights, len(self.datasets) + ) if path_to_cache: os.makedirs(path_to_cache, exist_ok=True) @@ -157,7 +168,7 @@ def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]: log_single_rank( logger, logging.WARNING, - "Unable to save the indexes because 
path_to_cache is None", + f"Cannot save the {type(self).__name__} indexes because path_to_cache is None", ) t_end = time.time() diff --git a/megatron/core/datasets/blended_megatron_dataset_builder.py b/megatron/core/datasets/blended_megatron_dataset_builder.py index 3dee4e4696..69b2775e1b 100644 --- a/megatron/core/datasets/blended_megatron_dataset_builder.py +++ b/megatron/core/datasets/blended_megatron_dataset_builder.py @@ -2,20 +2,28 @@ import logging import math -from typing import Any, List, Optional, Tuple, Type, Union +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Callable, Iterable, List, Optional, Type, Union import numpy import torch from megatron.core.datasets.blended_dataset import BlendedDataset from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig -from megatron.core.datasets.indexed_dataset import MMapIndexedDataset -from megatron.core.datasets.megatron_dataset import MegatronDataset +from megatron.core.datasets.megatron_dataset import LowLevelDataset, MegatronDataset from megatron.core.datasets.utils import Split, normalize +from megatron.core.parallel_state import get_virtual_pipeline_model_parallel_rank +from megatron.core.utils import log_single_rank logger = logging.getLogger(__name__) -DistributedDataset = Union[BlendedDataset, MegatronDataset, MMapIndexedDataset] +MidLevelDataset = MegatronDataset + +TopLevelDataset = Union[BlendedDataset, MidLevelDataset] + +DistributedDataset = Union[ + TopLevelDataset, MidLevelDataset, LowLevelDataset, torch.utils.data.Dataset +] class BlendedMegatronDatasetBuilder(object): @@ -24,219 +32,491 @@ class BlendedMegatronDatasetBuilder(object): Args: cls (Type[MegatronDataset]): The class to instantiate, must inherit from MegatronDataset - sizes (List[int]): The minimum number of total samples to draw from each split, varies - with blend + sizes (List[Optional[int]]): The minimum total number of samples to draw, or None, per split + + is_built_on_rank (Callable): A callable which returns True if the dataset should be built on + the current rank and False otherwise. It should be Megatron Core parallelism aware i.e. + global rank, local group rank, and virtual rank may inform its return value. 
config (BlendedMegatronDatasetConfig): The config object which informs dataset creation """ def __init__( - self, cls: Type[MegatronDataset], sizes: List[int], config: BlendedMegatronDatasetConfig, + self, + cls: Type[MidLevelDataset], + sizes: List[int], + is_built_on_rank: Callable, + config: BlendedMegatronDatasetConfig, ): self.cls = cls self.sizes = sizes + self.is_built_on_rank = is_built_on_rank self.config = config - def build(self) -> List[Optional[Union[BlendedDataset, MegatronDataset]]]: + log_single_rank( + logger, + logging.INFO, + f"Building {cls.__name__} splits with sizes={self.sizes} and config={self.config}", + ) + + if not self.config.mock: + for split in Split: + size_is_none = self.sizes[split.value] is None + if self.config.blend_per_split is None: + weights_are_none = self.config.blend[1] is None + else: + if self.config.blend_per_split[split.value] is None: + continue + weights_are_none = self.config.blend_per_split[split.value][1] is None + if size_is_none: + assert ( + weights_are_none + ), f"size_is_none => weights_are_none fails for {split.name} split" + + if torch.distributed.is_initialized(): + gb_rank = torch.distributed.get_rank() + vp_rank = get_virtual_pipeline_model_parallel_rank() + if gb_rank == 0 and (vp_rank == 0 or vp_rank is None): + assert ( + self.is_built_on_rank() + ), "is_built_on_rank must return True when global rank = 0 and vp rank = 0" + + def build(self) -> List[Optional[TopLevelDataset]]: """Build all dataset splits according to the provided blend(s) - + This method is distributed-aware and must be called on all ranks. - + The dataset splits returned can vary according to the config. Supply config.blend and config.split to build BlendedDataset and/or MegatronDataset splits from the same distribution. Supply config.blend_per_split to build BlendedDataset and/or MegatronDataset - splits from separate distributions. + splits from separate distributions. In either case, for each split, handle the following + cases: + + (1) The split is None + - do nothing + + (2) The split has one contributing dataset, and... + + (a) 'size' is not None + - Build a mid-level dataset with low-level dataset sampling in proportion to the + size + + (b) 'size' is None + - Build mid-level datasets with no excess low-level dataset sampling + + (3) The split has multiple contributing datasets, and... 
+ + (a) 'weights' is not None and 'size' is not None + - Build mid-level datasets with low-level dataset sampling in proportion to their + weights and the size + - Build a top-level dataset of length marginally greater than 'size' with mid-level + dataset sampling in proportion to their weights and the size + + (b) 'weights' is not None and 'size' is None + - Error + + (c) 'weights' is None and 'size' is not None + - Build mid-level datasets with no excess low-level dataset sampling + - Build a top-level dataset of length 'size' (capped at the sum of the mid-level + dataset lengths) with mid-level dataset sampling in proportion to their lengths + and the size + + (d) 'weights' is None and 'size' is None + - Build mid-level datasets with no excess low-level dataset sampling + - Build a top-level dataset with no excess mid-level dataset sampling Returns: - List[Optional[Union[BlendedDataset, MegatronDataset]]]: A list of either - MegatronDataset or BlendedDataset (or None) per split + List[Optional[TopLevelDataset]]: A list containing a dataset instance (or None) per + split """ - return self._build_blended_dataset_splits() - - def _build_blended_dataset_splits( - self, - ) -> List[Optional[Union[BlendedDataset, MegatronDataset]]]: + datasets = self._build_blended_dataset_splits() + + for dataset in datasets: + if dataset is not None and len(dataset) > 0: + if isinstance(dataset, BlendedDataset): + if dataset.built_anew_on_cache_miss or any( + x.built_anew_on_cache_miss for x in dataset.datasets + ): + log_single_rank( + logger, + logging.INFO, + ( + f"Verifying NumPy indices for {type(dataset).__name__} " + f"{dataset.split.name} split" + ), + ) + else: + log_single_rank( + logger, + logging.INFO, + ( + f"NumPy indices for {type(dataset).__name__} {dataset.split.name} " + f"split are fully cached, skipping verification" + ), + ) + continue + # Check blend size + assert dataset.size is None or dataset.size == dataset.dataset_index.shape[0] + # Check blend access of mid-level datasets + dataset_indices, dataset_sizes = numpy.unique( + dataset.dataset_index, return_counts=True + ) + for i, (index, size) in enumerate(zip(dataset_indices, dataset_sizes)): + if len(dataset.datasets[index]) < size: + raise IndexError( + f"The {dataset.split.name} blend oversamples the contributing " + f"datasets and, e.g., requests {size} samples from " + f"{type(dataset.datasets[index]).__name__} {i} with size " + f"{len(dataset.datasets[index])}. This is unexpected. " + f"Please file an issue." + ) + + return datasets + + def _build_blended_dataset_splits(self) -> List[Optional[TopLevelDataset]]: """Build all dataset splits according to the provided blend(s) - + See the BlendedMegatronDatasetBuilder.build alias for more information. 
Returns: - List[Optional[Union[BlendedDataset, MegatronDataset]]]: A list of either - MegatronDataset or BlendedDataset (or None) per split + List[Optional[TopLevelDataset]]: A list containing a dataset instance (or None) per + split """ - - if getattr(self.config, "blend"): - blend = getattr(self.config, "blend") - split = getattr(self.config, "split_vector") + ## + # Return fake "mock" datasets + ## + if self.config.mock: + split = self.config.split_matrix + try: + return self._build_megatron_dataset_splits(None, split, self.sizes) + except Exception as error: + raise Exception( + f"{self.cls.__name__} failed to build as a mock data generator" + ) from error + + ## + # All splits come from the same distribution + ## + elif self.config.blend: + prefixes, weights = self.config.blend + if weights is not None: + weights = normalize(weights) + + split = self.config.split_matrix # Blend consists of a single prefix - if len(blend) == 1: - return self._build_megatron_dataset_splits(blend[0], split, self.sizes) - - # Blend consists of multiple weights and prefixes - ( - prefix_per_dataset, - weight_per_dataset, - sizes_per_dataset, - ) = _get_prefixes_weights_and_sizes_for_blend(blend, self.sizes) + if len(prefixes) == 1 and weights is None: + return self._build_megatron_dataset_splits(prefixes[0], split, self.sizes) - megatron_datasets = [[] for _ in range(len(Split))] - - for i in range(len(prefix_per_dataset)): - megatron_datasets_split = self._build_megatron_dataset_splits( - prefix_per_dataset[i], split, sizes_per_dataset[i] + # Build the mid-level datasets + if weights is None: + # Build only one "epoch" + sizes_per_dataset_buffer = [[None for split in Split] for prefix in prefixes] + else: + # The number of samples we plan to use per dataset + sizes_per_dataset_target = _get_size_per_split_per_dataset(weights, self.sizes) + # The number of samples we plan to build per dataset + sizes_per_dataset_buffer = _get_size_per_split_per_dataset( + weights, self.sizes, margin=0.5 ) - for j in range(len(megatron_datasets_split)): - megatron_datasets[j].append(megatron_datasets_split[j]) - - # Sum over all contributing datasets, per split - size_per_split = list(map(sum, zip(*sizes_per_dataset))) - blended_datasets = [] + # Build each dataset in parallel + megatron_datasets = self._build_megatron_datasets_parallel( + prefixes, split, sizes_per_dataset_buffer + ) - for i in range(len(megatron_datasets)): - is_none = map(lambda _: _ is None, megatron_datasets[i]) - - if split[i] == 0.0: - assert all(is_none) - blended_datasets.append(None) - else: - assert all(is_none) or not any(is_none) - blended_datasets.append( - self._build_generic_dataset( - BlendedDataset, - megatron_datasets[i], - weight_per_dataset, - size_per_split[i], - self.config, + # Build the top-level datasets + blended_datasets = [None] * len(Split) + for i in range(len(Split)): + if split[i] is not None: + weights_i = weights + if weights_i is not None and self.sizes[i] is not None: + # Blend according to client-specified weights and client-specified size + size_per_dataset = list(zip(*sizes_per_dataset_target))[i] + size_i = sum(size_per_dataset) + elif weights_i is None: + # Blend according to dataset sizes as-is and (maybe) client-specified size + try: + weights_i = [ + len(megatron_dataset) for megatron_dataset in megatron_datasets[i] + ] + except TypeError: + weights_i = [0 for _ in prefixes] + if self.sizes[i] is not None: + size_i = min(self.sizes[i], sum(weights_i)) + else: + # Build exhaustive indices + size_i = None + else: + 
raise ValueError( + "Using client-specified weights requires client-specified size" ) + blended_datasets[i] = self.build_generic_dataset( + BlendedDataset, + self.is_built_on_rank, + True, # synchronize_ranks, default behavior to build on rank-0 first + megatron_datasets[i], + weights_i, + size_i, + self.config, ) return blended_datasets + ## + # Each split comes from a separate distribution + ## else: - blended_datasets = [] + blended_datasets = [None] * len(Split) for i in range(len(Split)): - blend = getattr(self.config, "blend_per_split")[i] - - # Blend is not provided - if not blend: - blended_datasets.append(None) - continue - - split_spoof = [0.0] * len(Split) - split_spoof[i] = 1.0 + split_spoof = [None] * len(Split) + split_spoof[i] = (0.0, 1.0) sizes_spoof = [0] * len(Split) sizes_spoof[i] = self.sizes[i] - # Blend consists of a sigle prefix - if len(blend) == 1: - blended_datasets.append( - self._build_megatron_dataset_splits(blend[0], split_spoof, sizes_spoof)[i] - ) - - # Blend consists of multiple weights and prefixes - else: - ( - prefix_per_dataset, - weight_per_dataset, - sizes_per_dataset, - ) = _get_prefixes_weights_and_sizes_for_blend(blend, sizes_spoof) - - megatron_datasets = [] - for j in range(len(prefix_per_dataset)): - megatron_datasets.append( - self._build_megatron_dataset_splits( - prefix_per_dataset[j], split_spoof, sizes_per_dataset[j], - )[i] + # Blend is provided for the split + blend = self.config.blend_per_split[i] + if blend is not None: + prefixes, weights = blend + if weights is not None: + weights = normalize(weights) + + # Blend consists of a sigle prefix + if len(prefixes) == 1: + blended_datasets[i] = self._build_megatron_dataset_splits( + prefixes[0], split_spoof, sizes_spoof + )[i] + continue + + # Build mid-level datasets + if weights is None: + sizes_per_dataset_buffer = [ + [None for split in Split] for prefix in prefixes + ] + else: + # The number of samples we plan to use per dataset + sizes_per_dataset_target = _get_size_per_split_per_dataset( + weights, sizes_spoof ) - - size_per_split = list(map(sum, zip(*sizes_per_dataset))) - - blended_datasets.append( - self._build_generic_dataset( - BlendedDataset, - megatron_datasets, - weight_per_dataset, - size_per_split[i], - self.config, + # The number of samples we plan to build per dataset + sizes_per_dataset_buffer = _get_size_per_split_per_dataset( + weights, sizes_spoof, margin=0.5 ) + + # Build each dataset in parallel + megatron_datasets = self._build_megatron_datasets_parallel( + prefixes, split_spoof, sizes_per_dataset_buffer + )[i] + + # Build top-level dataset + if weights is not None and self.sizes[i] is not None: + # Blend according to client-specified weights and client-specified size + size_per_dataset = list(zip(*sizes_per_dataset_target))[i] + size = sum(size_per_dataset) + elif weights is None: + # Blend according to dataset sizes as-is and (maybe) client-specified size + try: + weights = [ + len(megatron_dataset) for megatron_dataset in megatron_datasets + ] + except TypeError: + weights = [0 for _ in prefixes] + if self.sizes[i] is not None: + size = min(self.sizes[i], sum(weights)) + else: + # Build exhaustive indices + size = None + else: + raise RuntimeError + blended_datasets[i] = self.build_generic_dataset( + BlendedDataset, + self.is_built_on_rank, + True, # synchronize_ranks, default behavior to build on rank-0 first + megatron_datasets, + weights, + size, + self.config, ) return blended_datasets - def _build_megatron_dataset_splits( - self, path_prefix: str, split: 
List[float], sizes: List[int], - ) -> List[Optional[MegatronDataset]]: - """Build each MegatronDataset split from a single MMapIndexedDataset + def _build_megatron_datasets_parallel( + self, prefixes: List[str], split: List[float], sizes_per_dataset: List[List[int]] + ) -> List[List[Optional[MegatronDataset]]]: + """Build the megatron datasets for a list of prefixes in parallel Args: - path_prefix (str): The MMapIndexedDataset .bin and .idx file prefix + prefixes (List[str]): The list of prefix strings split (List[float]): The dataset split ratios (must sum to 1.00) - sizes (List[int]): The number of total samples to draw from each split + sizes_per_dataset (List[List[int]]): The number of samples to request + per MegatronDataset per spilt Returns: - List[Optional[MegatronDataset]]: The MegatronDatset (or None) per split + List[List[Optional[MegatronDataset]]]: For each split, have a list of + MegatronDataset per prefix """ - indexed_dataset = self._build_generic_dataset( - MMapIndexedDataset, path_prefix, self.cls.is_multimodal() - ) - if indexed_dataset is not None: - if self.cls.is_split_by_sequence(): - split_idx_bounds = _get_split_indices( - split, indexed_dataset.sequence_lengths.shape[0] - ) - else: - split_idx_bounds = _get_split_indices( - split, indexed_dataset.document_indices.shape[0] - 1 + # Helper function to wrap the threading logic + def _threading_helper( + megatron_datasets: List[List[Optional[MegatronDataset]]], + num_workers: int, + prefixes: List[str], + split: List[float], + sizes_per_dataset: List[List[int]], + ) -> None: + with ThreadPoolExecutor(max_workers=num_workers) as executor: + all_futures = [] + for i in range(len(prefixes)): + all_futures.append( + executor.submit( + self._build_megatron_dataset_splits, + prefixes[i], + split, + sizes_per_dataset[i], + False, # synchronize_ranks, barrier is called in this function + ) + ) + for future in all_futures: + try: + megatron_datasets_split = future.result() + for j in range(len(megatron_datasets_split)): + megatron_datasets[j].append(megatron_datasets_split[j]) + except Exception as err: + raise err + + megatron_datasets = [[] for _ in range(len(Split))] + num_dataset_builder_threads = self.config.num_dataset_builder_threads + + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + # First, build on rank 0 + if rank == 0: + num_workers = num_dataset_builder_threads + if num_workers > 1: + # since only rank 0 is running, scale up the thread count + # but not too much to avoid overloading storage on miss path. + # if user set num_dataset_builder_threads to 1, + # i.e. meant for serial build, do not scale up. 
+ num_workers *= min(2, max(1, torch.cuda.device_count())) + _threading_helper( + megatron_datasets, num_workers, prefixes, split, sizes_per_dataset ) - split_indices = [ - numpy.arange( - start=split_idx_bounds[i], - stop=split_idx_bounds[i + 1], - step=1, - dtype=numpy.int32, + + torch.distributed.barrier() + + # Then, build on other ranks; guaranteed to be data_cache hit + if rank != 0: + _threading_helper( + megatron_datasets, + num_dataset_builder_threads, + prefixes, + split, + sizes_per_dataset, ) - for i, _ in enumerate(Split) - ] else: - split_indices = [None for _ in Split] + _threading_helper( + megatron_datasets, num_dataset_builder_threads, prefixes, split, sizes_per_dataset + ) + + return megatron_datasets - megatron_datasets = [] + def _build_megatron_dataset_splits( + self, + dataset_path: Optional[str], + split: List[float], + sizes: List[int], + synchronize_ranks: bool = True, + ) -> List[Optional[MidLevelDataset]]: + """Build each MidLevelDataset split from a single LowLevelDataset + + Args: + dataset_path (Optional[str]): The path on disk which defines the underlying + LowLevelDataset, or None for mock dataset classes + + split (List[Tuple[float, float]]): The dataset split matrix + + sizes (List[int]): The number of total samples to draw from each split + + synchronize_ranks (bool): Whether to call barrier for rank-0 / barrier / other-ranks + behavior. Set to False when we enforce this behavior at higher level. + + Returns: + List[Optional[MidLevelDataset]]: The MidLevelDataset (or None) per split + """ + # short-cut if we are not building on this rank + if torch.distributed.is_initialized() and not self.is_built_on_rank(): + for i in range(len(Split)): + if split[i] is not None and synchronize_ranks: + torch.distributed.barrier() + return [None] * len(Split) + + # Build the low level dataset + low_level_dataset = self.cls.build_low_level_dataset(dataset_path, self.config) + + # Build the split indices for the low level dataset + num_elements = self.cls.numel_low_level_dataset(low_level_dataset) + split_indices = [] + for i, _ in enumerate(Split): + if split[i] is not None: + beg = int(round(split[i][0] * float(num_elements))) + end = int(round(split[i][1] * float(num_elements))) + split_indices.append(numpy.arange(start=beg, stop=end, step=1, dtype=numpy.int32)) + else: + split_indices.append(None) + + # Build the mid level dataset + mid_level_datasets = [] for i, _split in enumerate(Split): - if split[i] == 0.0: - megatron_datasets.append(None) + if split[i] is None: + mid_level_datasets.append(None) else: - megatron_datasets.append( - self._build_generic_dataset( - self.cls, indexed_dataset, split_indices[i], sizes[i], _split, self.config + mid_level_datasets.append( + self.build_generic_dataset( + self.cls, + self.is_built_on_rank, + synchronize_ranks, + low_level_dataset, + dataset_path, + split_indices[i], + sizes[i], + _split, + self.config, ) ) - return megatron_datasets + return mid_level_datasets - def _build_generic_dataset( - self, cls: Type[DistributedDataset], *args: Any, - ) -> Optional[DistributedDataset]: + @staticmethod + def build_generic_dataset( + cls: Union[Type[DistributedDataset], Callable], + is_built_on_rank: Callable, + synchronize_ranks: bool, + *args: Any, + ) -> Optional[Union[DistributedDataset, Iterable]]: """Build the DistributedDataset - Return None if and only if the underlying MegatronDataset class is not built on the current - rank and torch.distributed is initialized. 
+ Return None if and only if the underlying dataset class is not built on the current rank + and torch.distributed is initialized. Args: - cls (Type[DistributedDataset]): The DistributedDataset class to be built + cls (Union[Type[DistributedDataset], Callable]): The DistributedDataset class to be + built. In special cases, e.g. when we are building the low level dataset for a + RawMegatronDataset instance, we can accept a Callable which returns an Iterable. + + synchronize_ranks (bool): Whether to call barrier for rank-0 / barrier / other-ranks + behavior. Set to False when we enforce this behavior at higher level. args (Tuple[Any]): The positional arguments used to build the provided - DistributedDataset class + DistributedDataset class Raises: Exception: When the dataset constructor raises an OSError Returns: - Optional[DistributedDataset]: The DistributedDataset instantion or None + Optional[Union[DistributedDataset, Iterable]]: The DistributedDataset instantion, the + Iterable instantiation, or None """ if torch.distributed.is_initialized(): rank = torch.distributed.get_rank() @@ -244,22 +524,23 @@ def _build_generic_dataset( dataset = None # First, build on rank 0 - if rank == 0 and getattr(self.config, "is_built_on_rank")(): + if rank == 0 and is_built_on_rank(): try: dataset = cls(*args) except OSError as err: log = ( - f"Failed to write dataset materials to the data cache directory. " - + f"Please supply a directory to which you have write access via " - + f"the path_to_cache attribute in BlendedMegatronDatasetConfig and " - + f"retry. Refer to the preserved traceback above for more information." + f"Failed to write dataset materials to the data cache directory. Please " + f"supply a directory to which you have write access via the path_to_cache " + f"attribute in BlendedMegatronDatasetConfig and retry. Refer to the " + f"preserved traceback above for more information." ) raise Exception(log) from err - torch.distributed.barrier() + if synchronize_ranks: + torch.distributed.barrier() # After, build on other ranks - if rank != 0 and getattr(self.config, "is_built_on_rank")(): + if rank != 0 and is_built_on_rank(): dataset = cls(*args) return dataset @@ -267,62 +548,32 @@ def _build_generic_dataset( return cls(*args) -def _get_split_indices(split: List[float], num_elements: int) -> List[int]: - """Determine the document index bounds per split +def _get_size_per_split_per_dataset( + normalized_weights: List[float], target_size_per_split: List[int], margin: float = 0.0 +) -> List[List[int]]: + """Determine the contribution of the MegatronDataset splits to the BlendedDataset splits Args: - split (List[float]): The dataset split ratios (must sum to 1.00) + normalized_weights (List[float]): e.g. [0.3, 0.7] - num_elements (int): The number of elements, e.g. sequences or documents, available for - the split - - Returns: - List[int]: The indices for all three splits e.g. 
[0, 900, 990, 1000] for a 1000-document - set and a [90.0, 9.0, 1.0] split - """ - split_indices = [0] - for split_pct in split: - split_indices.append(split_indices[-1] + int(round(split_pct * float(num_elements)))) - split_indices[1:] = list( - map(lambda _: _ - (split_indices[-1] - num_elements), split_indices[1:]) - ) - - assert len(split_indices) == len(split) + 1 - assert split_indices[-1] == num_elements - - return split_indices + target_size_per_split (List[int]): The number of samples to target for each BlendedDataset + split - -def _get_prefixes_weights_and_sizes_for_blend( - blend: List[str], target_num_samples_per_split: List[int] -) -> Tuple[List[str], List[float], List[List[int]]]: - """Determine the contribution of the MegatronDataset splits to the BlendedDataset splits - - Args: - blend (List[str]): e.g. ["30", "path/to/dataset_1_prefix", "70", - "path/to/dataset_2_prefix"] - - target_num_samples_per_split (List[int]): The number of samples to target for each - BlendedDataset split + margin (float): The relative quantity of extra samples to build per per split per dataset, + as a percentage Returns: - Tuple[List[str], List[float], List[List[int]]]: The prefix strings e.g. - ["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], the normalized weights e.g. - [0.3, 0.7], and the number of samples to request per MegatronDataset per split + List[List[int]]: The number of samples to request per MegatronDataset per split """ - weights, prefixes = zip( - *[(float(blend[i]), blend[i + 1].strip()) for i in range(0, len(blend), 2)] - ) - - weights = normalize(weights) + assert numpy.isclose(sum(normalized_weights), 1.0) - # Use 0.5% target margin to ensure we satiate the network + # Use margin as buffer to ensure we satiate the request sizes_per_dataset = [ [ - int(math.ceil(target_num_samples * weight * 1.005)) - for target_num_samples in target_num_samples_per_split + int(math.ceil(math.ceil(target_size * weight) * (1 + margin / 100))) + for target_size in target_size_per_split ] - for weight in weights + for weight in normalized_weights ] - return prefixes, weights, sizes_per_dataset + return sizes_per_dataset diff --git a/megatron/core/datasets/blended_megatron_dataset_config.py b/megatron/core/datasets/blended_megatron_dataset_config.py index b7e242a4be..a426bd3e50 100644 --- a/megatron/core/datasets/blended_megatron_dataset_config.py +++ b/megatron/core/datasets/blended_megatron_dataset_config.py @@ -1,112 +1,119 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +import functools import logging import re from dataclasses import dataclass, field -from typing import Callable, List, Optional - -import torch +from typing import List, Optional, Tuple +from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer from megatron.core.datasets.utils import Split, log_single_rank, normalize -from megatron.core.parallel_state import get_virtual_pipeline_model_parallel_rank logger = logging.getLogger(__name__) @dataclass class BlendedMegatronDatasetConfig: - """Configuration object for megatron-core blended and megatron datasets - - Attributes: - is_built_on_rank (Callable): A callable which returns True if the dataset should be built - on the current rank. It should be Megatron Core parallelism aware i.e. global rank, group - rank, and virtual rank may inform its return value. - - random_seed (int): The seed for all RNG during dataset creation. - - sequence_length (int): The sequence length. 
- - blend (Optional[List[str]]): The blend string, consisting of either a single dataset or a - flattened sequential sequence of weight-dataset pairs. For exampe, ["dataset-path1"] and - ["50", "dataset-path1", "50", "dataset-path2"] are both valid. Not to be used with - 'blend_per_split'. Defaults to None. - - blend_per_split (blend_per_split: Optional[List[Optional[List[str]]]]): A set of blend - strings, as defined above, one for each split distribution. Not to be used with 'blend'. - Defauls to None. + """Configuration object for Megatron Core datasets""" - split (Optional[str]): The split string, a comma separated weighting for the dataset splits - when drawing samples from a single distribution. Not to be used with 'blend_per_split'. - Defaults to None. + random_seed: int + """The seed for all RNG during dataset creation.""" - split_vector: (Optional[List[float]]): The split string, parsed and normalized post- - initialization. Not to be passed to the constructor. + sequence_length: int + """The sequence length.""" - path_to_cache (str): Where all re-useable dataset indices are to be cached. + blend: Optional[Tuple[List[str], Optional[List[float]]]] = None + """The blend, consisting of a list of dataset prefixes and optionally a list of dataset + weights. For example, [["dataset-path1", "dataset-path2"], [0.3, 0.7]]. When the weights are + None, they are inferred from the lengths of the contributing datasets. Not to be used with + 'blend_per_split'. Defaults to None. """ - is_built_on_rank: Callable - - random_seed: int + blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]] = None + """A set of blends, as defined above, one for each split distribution. Not to be used with + 'blend'. Defauls to None. + """ - sequence_length: int + split: Optional[str] = None + """The split string, a comma separated weighting for the dataset splits when drawing samples + from a single distribution. Not to be used with 'blend_per_split'. Defaults to None. + """ - blend: Optional[List[str]] = None + split_matrix: Optional[List[Tuple[float, float]]] = field(init=False, default=None) + """The split matrix consisting of non-overlapping book-ends of each split in order. For more + information, refer to 'convert_split_vector_to_split_matrix'. Created automatically from + 'split'. Not to be passed in to the constructor. + """ - blend_per_split: Optional[List[Optional[List[str]]]] = None + num_dataset_builder_threads: int = 1 + """The number of threads to use for dataset building.""" - split: Optional[str] = None + path_to_cache: Optional[str] = None + """Where all re-useable dataset indices are to be cached.""" - split_vector: Optional[List[float]] = field(init=False, default=None) + mmap_bin_files: bool = True + """Whether to mmap the .bin files or use file pointers.""" - path_to_cache: str = None + mock: bool = field(init=False, default=False) + """Whether to bypass real data loading and validation in favor of mock data generation. + Created automatically from 'blend' and 'blend_per_split'. Not to be passed in to the + constructor. + """ - def __post_init__(self): - """Python dataclass method that is used to modify attributes after initialization. See - https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details. 
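As a usage illustration of the new blend format (the dataset prefixes and cache path below are placeholders), a config might be constructed as follows; __post_init__ then parses 'split' and fills in 'split_matrix':

    from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig

    # Two dataset prefixes weighted 30/70; pass weights=None to infer the weights
    # from the lengths of the contributing datasets instead.
    config = BlendedMegatronDatasetConfig(
        random_seed=1234,
        sequence_length=2048,
        blend=(["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], [0.3, 0.7]),
        split="99,1,0",
        path_to_cache="/path/to/index/cache",
    )
    # config.split_matrix is roughly [(0, 0.99), (0.99, 1.0), None] and config.mock is False.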
- """ - if torch.distributed.is_initialized(): - gb_rank = torch.distributed.get_rank() - vp_rank = get_virtual_pipeline_model_parallel_rank() - if gb_rank == 0 and (vp_rank == 0 or vp_rank is None): - assert ( - self.is_built_on_rank() - ), "is_built_on_rank must return True when global rank = 0 and vp rank = 0" + tokenizer: Optional[MegatronTokenizer] = None + """The MegatronTokenizer instance. Required for datasets that do online tokenization.""" + def __post_init__(self) -> None: + """Do asserts and set fields post init""" if self.blend_per_split is not None and any(self.blend_per_split): assert self.blend is None, "blend and blend_per_split are incompatible" + assert self.split is None, "split and blend_per_split are incompatible" assert len(self.blend_per_split) == len( Split ), f"blend_per_split must contain {len(Split)} blends" - if self.split is not None: - self.split = None - log_single_rank(logger, logging.WARNING, f"Let split = {self.split}") + for split in Split: + if self.blend_per_split[split.value] is None: + log_single_rank( + logger, logging.INFO, f"blend not provided for {split.name} split" + ) + else: + assert self.blend_per_split[split.value][1] is None or len( + self.blend_per_split[split.value][0] + ) == len( + self.blend_per_split[split.value][1] + ), "blend per split prefixes and weights must be equal in number" else: - assert self.blend is not None, "one of either blend or blend_per_split must be provided" - assert self.split is not None, "both blend and split must be provided" - self.split_vector = _parse_and_normalize_split(self.split) - log_single_rank(logger, logging.INFO, f"Let split_vector = {self.split_vector}") - - -@dataclass -class GPTDatasetConfig(BlendedMegatronDatasetConfig): - """Configuration object for megatron-core blended and megatron GPT datasets - - Attributes: - return_document_ids (bool): Whether to return the document ids when querying the dataset. - """ - - return_document_ids: bool = False - - -def _parse_and_normalize_split(split: str) -> List[float]: + if self.blend is not None: + assert self.blend[1] is None or len(self.blend[0]) == len( + self.blend[1] + ), "blend prefixes and weights must be equal in number" + assert self.split is not None, "split must be provided when blend is not None" + else: + self.mock = True + log_single_rank( + logger, + logging.INFO, + f"Let mock = True, as both blend and blend_per_split are None", + ) + self.split = "1,1,1" + log_single_rank( + logger, + logging.INFO, + f"Let split = {self.split}, an arbitrarily even split, as mock is True", + ) + split_vector = parse_and_normalize_split(self.split) + self.split_matrix = convert_split_vector_to_split_matrix(split_vector) + log_single_rank(logger, logging.INFO, f"Let split_matrix = {self.split_matrix}") + + +def parse_and_normalize_split(split: str) -> List[float]: """Parse the dataset split ratios from a string Args: split (str): The train valid test split string e.g. "99,1,0" Returns: - List[float]: The trian valid test split ratios e.g. [99.0, 1.0, 0.0] + List[float]: The trian valid test split ratios e.g. 
[0.99, 0.01, 0.0] """ split = list(map(float, re.findall(r"[.0-9]+", split))) split = split + [0.0 for _ in range(len(Split) - len(split))] @@ -117,3 +124,49 @@ def _parse_and_normalize_split(split: str) -> List[float]: split = normalize(split) return split + + +def convert_split_vector_to_split_matrix( + vector_a: List[float], vector_b: Optional[List[float]] = None +) -> List[Optional[Tuple[float, float]]]: + """Build the split matrix from one or optionally two contributing split vectors. + + Ex. a standard conversion: + + [0.99, 0.01, 0.0] -> [(0, 0.99), (0.99, 1.0), None] + + Ex. a conversion for Retro when Retro pretraining uses a [0.99, 0.01, 0.0] split and Retro + preprocessing used a [0.98, 0.02, 0.0] split: + + [0.99, 0.01, 0.0], [0.98, 0.02, 0.0] -> [(0, 0.98), (0.99, 1.0), None] + + Args: + vector_a (List[float]): The primary split vector + + vector_b (Optional[List[float]]): An optional secondary split vector which constrains the + primary split vector. Defaults to None. + + Returns: + List[Tuple[float, float]]: The split matrix consisting of book-ends of each split in order + """ + if vector_b is None: + vector_b = vector_a + + # [.900, .090, .010] -> [0.00, .900, .990, 100] + expansion_a = functools.reduce(lambda a, b: a + [a[len(a) - 1] + b], [[0], *vector_a]) + expansion_b = functools.reduce(lambda a, b: a + [a[len(a) - 1] + b], [[0], *vector_b]) + + # [0.00, .900, .990, 100.0] -> [(0.00, .900), (.900, .990), (.990, 100)] + bookends_a = list(zip(expansion_a[:-1], expansion_a[1:])) + bookends_b = list(zip(expansion_b[:-1], expansion_b[1:])) + + # gather per-split overlap or None + matrix = [] + for bookend_a, bookend_b in zip(bookends_a, bookends_b): + if min(bookend_a[1], bookend_b[1]) <= max(bookend_a[0], bookend_b[0]): + overlap = None + else: + overlap = (max(bookend_a[0], bookend_b[0]), min(bookend_a[1], bookend_b[1])) + matrix.append(overlap) + + return matrix diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index 1004e649a2..2eb7702b54 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -1,101 +1,234 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
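For reference, the two split helpers above compose as follows; return values echo the docstring examples and are shown modulo floating-point rounding.

    from megatron.core.datasets.blended_megatron_dataset_config import (
        convert_split_vector_to_split_matrix,
        parse_and_normalize_split,
    )

    vector = parse_and_normalize_split("99,1,0")            # -> [0.99, 0.01, 0.0]
    matrix = convert_split_vector_to_split_matrix(vector)
    # -> [(0, 0.99), (0.99, 1.0), None]

    # Constraining against the split used at preprocessing time (the Retro case above):
    convert_split_vector_to_split_matrix([0.99, 0.01, 0.0], [0.98, 0.02, 0.0])
    # -> [(0, 0.98), (0.99, 1.0), None]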
import logging import os import time -from typing import Dict, Tuple +from dataclasses import dataclass +from typing import Dict, Optional, Tuple import numpy import torch -from megatron.core.datasets.blended_megatron_dataset_config import GPTDatasetConfig -from megatron.core.datasets.indexed_dataset import MMapIndexedDataset +from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig +from megatron.core.datasets.indexed_dataset import IndexedDataset from megatron.core.datasets.megatron_dataset import MegatronDataset -from megatron.core.datasets.utils import Split, log_single_rank +from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer +from megatron.core.datasets.utils import Split +from megatron.core.datasets.utils_s3 import S3Config, is_s3_path +from megatron.core.utils import log_single_rank logger = logging.getLogger(__name__) +_PAD_TOKEN_ID = -1 + + +@dataclass +class GPTDatasetConfig(BlendedMegatronDatasetConfig): + """Configuration object for Megatron Core GPT datasets""" + + reset_position_ids: bool = None + """Option to reset the position IDs in the dataset at an interval""" + + reset_attention_mask: bool = None + """Option to reset the attention mask from the dataset""" + + eod_mask_loss: bool = None + """Option to enable the EOD mask loss""" + + create_attention_mask: bool = True + """Option to enable the attention masks generation. Can be disabled if attention kernel + generates masks by itself. + """ + + drop_last_partial_validation_sequence: bool = True + """Option to drop the last partial validation sequence""" + + add_extra_token_to_sequence: bool = True + """Option to draw sequences with one extra token to ensure the sample input tokens and sample + output tokens are both of the desired sequence length + """ + + s3_cache_path: str = None + """Path for caching indices for s3 dataloading.""" + + def __post_init__(self) -> None: + """Do asserts and set fields post init""" + super().__post_init__() + + assert self.tokenizer is not None + + assert self.reset_position_ids is not None + assert self.reset_attention_mask is not None + assert self.eod_mask_loss is not None + class GPTDataset(MegatronDataset): """The base GPT dataset Args: - indexed_dataset (MMapIndexedDataset): The MMapIndexedDataset around which to build the - MegatronDataset + indexed_dataset (IndexedDataset): The IndexedDataset around which to build the GPTDataset + + dataset_path (Optional[str]): The real path on disk to the dataset, for bookkeeping indexed_indices (numpy.ndarray): The set of the documents indices to expose - num_samples (int): The number of samples to draw from the indexed dataset + num_samples (Optional[int]): The number of samples to draw from the indexed dataset. When + None, build as many samples as correspond to one epoch. 
index_split (Split): The indexed_indices Split - config (GPTDatasetConfig): The GPT-specific container for all config sourced parameters + config (GPTDatasetConfig): The config """ def __init__( self, - indexed_dataset: MMapIndexedDataset, + indexed_dataset: IndexedDataset, + dataset_path: Optional[str], indexed_indices: numpy.ndarray, - num_samples: int, + num_samples: Optional[int], index_split: Split, config: GPTDatasetConfig, ) -> None: - super().__init__(indexed_dataset, indexed_indices, num_samples, index_split, config) + super().__init__( + indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config + ) + self.masks_and_position_ids_are_cacheable = not any( + [ + self.config.reset_position_ids, + self.config.reset_attention_mask, + self.config.eod_mask_loss, + ] + ) + self.masks_and_position_ids_are_cached = False + self.cached_attention_mask = None + self.cached_loss_mask = None + self.cached_position_ids = None + + try: + self._pad_token_id = self.config.tokenizer.pad + except Exception: + self._pad_token_id = _PAD_TOKEN_ID + + (self.document_index, self.sample_index, self.shuffle_index) = ( + self._build_document_sample_shuffle_indices() + ) - def _finalize(self) -> None: + @staticmethod + def numel_low_level_dataset(low_level_dataset: IndexedDataset) -> int: """Abstract method implementation - - Load or build/cache the document, sample, and shuffle indices - """ - assert isinstance(self.config, GPTDatasetConfig) - ( - self.document_index, - self.sample_index, - self.shuffle_index, - ) = self._build_document_sample_shuffle_indices() + For GPT, the underlying IndexedDataset should be split by sequence, as opposed to, say, + BERT, which should be split by document - def __len__(self) -> int: - """Abstract method implementation + Args: + low_level_dataset (IndexedDataset): The underlying IndexedDataset Returns: - int: The length of the dataset + int: The number of unique elements in the underlying IndexedDataset """ - return self.sample_index.shape[0] - 1 + return low_level_dataset.sequence_lengths.shape[0] - def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]: + @staticmethod + def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> IndexedDataset: """Abstract method implementation Args: - idx (int): The index into the dataset + dataset_path (str): The real path prefix to the IndexedDataset .bin and .idx files + + config (GPTDatasetConfig): The config Returns: - Dict[str, numpy.ndarray]: The text ids and (optionally) the document ids wrapped in a - dictionary + IndexedDataset: The underlying IndexedDataset """ - text, document_ids = self._query_document_sample_shuffle_indices(idx) - if getattr(self.config, "return_document_ids"): - return {"text": text, "document_ids": document_ids} - else: - return {"text": text} + if is_s3_path(dataset_path): + return IndexedDataset( + dataset_path, + multimodal=False, + mmap=config.mmap_bin_files, + s3_config=S3Config(path_to_idx_cache=config.s3_cache_path), + ) + return IndexedDataset(dataset_path, multimodal=False, mmap=config.mmap_bin_files) - @staticmethod - def is_multimodal() -> bool: + def __len__(self) -> int: """Abstract method implementation Returns: - bool: False + int: The length of the dataset """ - return False + return self.sample_index.shape[0] - 1 - @staticmethod - def is_split_by_sequence() -> bool: + def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]: """Abstract method implementation + Args: + idx (Optioal[int]): The index into the dataset + Returns: - bool: 
True + Dict[str, torch.Tensor]: The sample information wrapped in a dictionary """ - return True + if idx is None: + # Batch padding sequence so the index does not matter + text, _ = self._query_document_sample_shuffle_indices(0) + else: + text, _ = self._query_document_sample_shuffle_indices(idx) + + text = torch.from_numpy(text).long() + if self.config.add_extra_token_to_sequence: + tokens = text[:-1].contiguous() + labels = text[1:].contiguous() + else: + tokens = text + labels = torch.roll(text, shifts=-1, dims=0) + labels[-1] = self._pad_token_id + + if ( + not self.masks_and_position_ids_are_cacheable + or not self.masks_and_position_ids_are_cached + ): + attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids( + tokens, + self.config.tokenizer.eod, + self.config.reset_position_ids, + self.config.reset_attention_mask, + self.config.eod_mask_loss, + self.config.create_attention_mask, + ) + if self.masks_and_position_ids_are_cacheable: + self.cached_attention_mask = attention_mask + self.cached_loss_mask = loss_mask + self.cached_position_ids = position_ids + self.masks_and_position_ids_are_cached = True + else: + attention_mask = self.cached_attention_mask + loss_mask = self.cached_loss_mask + position_ids = self.cached_position_ids + + # For padded sequences, mask the loss + loss_mask[labels == self._pad_token_id] = 0.0 + + # For padded sequences, ensure the embedding layer can map the token ID + tokens[tokens == self._pad_token_id] = 0 + labels[labels == self._pad_token_id] = 0 + + # Batch padding sequence so we mask the loss + if idx is None: + loss_mask = torch.zeros_like(loss_mask) + + if self.config.create_attention_mask: + return { + "tokens": tokens, + "labels": labels, + "attention_mask": attention_mask, + "loss_mask": loss_mask, + "position_ids": position_ids, + } + else: + return { + "tokens": tokens, + "labels": labels, + "loss_mask": loss_mask, + "position_ids": position_ids, + } def _query_document_sample_shuffle_indices( self, idx: int @@ -125,10 +258,12 @@ def _query_document_sample_shuffle_indices( # Add the entire sample sample_parts.append( - self.indexed_dataset.get( + self.dataset.get( self.document_index[doc_index_beg], offset=doc_index_beg_offset, - length=doc_index_end_offset - doc_index_beg_offset + 1, + length=doc_index_end_offset + - doc_index_beg_offset + + self.config.add_extra_token_to_sequence, ) ) @@ -140,13 +275,29 @@ def _query_document_sample_shuffle_indices( # Add the sample part offset = 0 if i > doc_index_beg else doc_index_beg_offset - length = None if i < doc_index_end else doc_index_end_offset + 1 + length = ( + None + if i < doc_index_end + else doc_index_end_offset + self.config.add_extra_token_to_sequence + ) sample_parts.append( - self.indexed_dataset.get(self.document_index[i], offset=offset, length=length) + self.dataset.get(self.document_index[i], offset=offset, length=length) ) + assert len(document_ids) == len( + sample_parts + ), f"len(document_ids) ({len(document_ids)}) != len(sample_parts) ({len(sample_parts)})" + + length = sum(map(len, sample_parts)) + + # Pad the sample if necessary + if length < (self.config.sequence_length + self.config.add_extra_token_to_sequence): + sample_parts.append( + [self._pad_token_id] + * (self.config.sequence_length + self.config.add_extra_token_to_sequence - length) + ) return ( - numpy.array(numpy.concatenate(sample_parts), dtype=numpy.int64), + numpy.concatenate(sample_parts, dtype=numpy.int64), numpy.array(document_ids, dtype=numpy.int64), ) @@ -154,7 +305,7 @@ def 
_build_document_sample_shuffle_indices( self, ) -> Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: """Build the document index, the sample index, and the shuffle index - + The document index: -- 1-D -- An ordered array of document ids @@ -168,58 +319,65 @@ def _build_document_sample_shuffle_indices( -- A random permutation of index range of the sample index Returns: - Tuple[numpy.ndarray, numpy.ndarray]: The document index, the sample index, and the - shuffle index - - TODO: Explain the 80% threshold + Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: The document index, the sample + index, and the shuffle index """ - path_to_cache = getattr(self.config, "path_to_cache") - if path_to_cache is None: + path_to_cache = self.config.path_to_cache + if path_to_cache is None and not self.config.mock: path_to_cache = os.path.join( - self.indexed_dataset.path_prefix, "cache", f"{type(self).__name__}_indices" + self.dataset.path_prefix, "cache", f"{type(self).__name__}_indices" ) - get_path_to = lambda suffix: os.path.join( - path_to_cache, f"{self.unique_description_hash}-{type(self).__name__}-{suffix}" - ) - path_to_description = get_path_to("description.txt") - path_to_document_index = get_path_to("document_index.npy") - path_to_sample_index = get_path_to("sample_index.npy") - path_to_shuffle_index = get_path_to("shuffle_index.npy") - cache_hit = all( - map( - os.path.isfile, - [ - path_to_description, - path_to_document_index, - path_to_sample_index, - path_to_shuffle_index, - ], + if path_to_cache: + base = f"{self.unique_description_hash}-{type(self).__name__}-{self.index_split.name}" + get_path_to = lambda affix: os.path.join(path_to_cache, f"{base}-{affix}") + path_to_description = get_path_to("description.txt") + path_to_document_index = get_path_to("document_index.npy") + path_to_sample_index = get_path_to("sample_index.npy") + path_to_shuffle_index = get_path_to("shuffle_index.npy") + cache_hit = all( + map( + os.path.isfile, + [ + path_to_description, + path_to_document_index, + path_to_sample_index, + path_to_shuffle_index, + ], + ) ) - ) - - num_tokens_per_epoch = _get_num_tokens_per_epoch(self.indexed_dataset, self.indexed_indices) - - sequence_length = getattr(self.config, "sequence_length") + else: + cache_hit = False - num_epochs = _get_num_epochs(num_tokens_per_epoch, sequence_length, self.num_samples) + if not path_to_cache or ( + not cache_hit + and (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0) + ): - if not cache_hit and torch.distributed.get_rank() == 0: log_single_rank( logger, logging.INFO, f"Build and save the {type(self).__name__} {self.index_split.name} indices", ) + self.built_anew_on_cache_miss = True + t_beg = time.time() + + sequence_length = self.config.sequence_length + num_tokens_per_epoch = self._get_num_tokens_per_epoch() + num_epochs = self._get_num_epochs(num_tokens_per_epoch) if num_epochs == 1: separate_final_epoch = False else: # Get the number of samples for the last epoch num_samples_sans_final_epoch = ( - (num_epochs - 1) * num_tokens_per_epoch - 1 + (num_epochs - 1) * num_tokens_per_epoch + - self.config.add_extra_token_to_sequence ) // sequence_length num_samples_from_final_epoch = self.num_samples - num_samples_sans_final_epoch - num_samples_per_epoch = (num_tokens_per_epoch - 1) // sequence_length + num_samples_per_epoch = ( + num_tokens_per_epoch - self.config.add_extra_token_to_sequence + ) // sequence_length # num_samples_from_final_epoch should be non-negative assert num_samples_from_final_epoch >= 0 @@ 
-247,57 +405,49 @@ def _build_document_sample_shuffle_indices( logger, logging.DEBUG, f"> separate_final_epoch: {separate_final_epoch}" ) - numpy_random_state = numpy.random.RandomState(getattr(self.config, "random_seed")) - - os.makedirs(path_to_cache, exist_ok=True) - - # Write the description - with open(path_to_description, "wt") as writer: - writer.write(self.unique_description) + numpy_random_state = numpy.random.RandomState(self.config.random_seed) # Build the document index - log_single_rank( - logger, - logging.INFO, - f"\tBuild and save the document index to {os.path.basename(path_to_document_index)}", - ) - t_beg = time.time() document_index = _build_document_index( - self.indexed_indices, num_epochs, numpy_random_state, separate_final_epoch + self.indices, num_epochs, numpy_random_state, separate_final_epoch ) - numpy.save(path_to_document_index, document_index, allow_pickle=True) - t_end = time.time() - log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") + + drop_last_partial_sequence = True + if self.index_split == Split.valid: + drop_last_partial_sequence = self.config.drop_last_partial_validation_sequence # Build the sample index - log_single_rank( - logger, - logging.INFO, - f"\tBuild and save the sample index to {os.path.basename(path_to_sample_index)}", - ) - t_beg = time.time() from megatron.core.datasets import helpers + if self.index_split == Split.valid: + drop_last_partial_sequence = self.config.drop_last_partial_validation_sequence + else: + drop_last_partial_sequence = True + assert document_index.dtype == numpy.int32 - assert self.indexed_dataset.sequence_lengths.dtype == numpy.int32 + assert self.dataset.sequence_lengths.dtype == numpy.int32 + if len(document_index) * 2 > len(self.dataset.sequence_lengths): + # If "access density" of sequence_lengths is high, force load the mmap-ed array + # into memory by making a copy. + # + # System performance benefits come from two aspects: + # 1. We sequentially pre-load the whole file, most of which we expect to read + # 2. 
The GIL is held when entering the c++ program, improving the speed of which + # improves parallelism + sequence_lengths_for_cpp = self.dataset.sequence_lengths.copy() + else: + sequence_lengths_for_cpp = self.dataset.sequence_lengths sample_index = helpers.build_sample_idx( - self.indexed_dataset.sequence_lengths, + sequence_lengths_for_cpp, document_index, sequence_length, num_epochs, num_tokens_per_epoch, + drop_last_partial_sequence, + self.config.add_extra_token_to_sequence, ) - numpy.save(path_to_sample_index, sample_index, allow_pickle=True) - t_end = time.time() - log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") # Build the shuffle index - log_single_rank( - logger, - logging.INFO, - f"\tBuild and save the shuffle index to {os.path.basename(path_to_shuffle_index)}", - ) - t_beg = time.time() if separate_final_epoch: shuffle_index = _build_shuffle_index( num_samples_sans_final_epoch, sample_index.shape[0] - 1, numpy_random_state @@ -306,10 +456,32 @@ def _build_document_sample_shuffle_indices( shuffle_index = _build_shuffle_index( sample_index.shape[0] - 1, sample_index.shape[0] - 1, numpy_random_state ) - numpy.save(path_to_shuffle_index, shuffle_index, allow_pickle=True) + + if path_to_cache: + os.makedirs(path_to_cache, exist_ok=True) + # Write the description + with open(path_to_description, "wt") as writer: + writer.write(self.unique_description) + numpy.save(path_to_document_index, document_index, allow_pickle=True) + numpy.save(path_to_sample_index, sample_index, allow_pickle=True) + numpy.save(path_to_shuffle_index, shuffle_index, allow_pickle=True) + else: + log_single_rank( + logger, + logging.WARNING, + f"Unable to save {type(self).__name__} indexes because path_to_cache is None", + ) + t_end = time.time() log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") + log_single_rank( + logger, logging.INFO, f"> total number of samples: {sample_index.shape[0] - 1}" + ) + log_single_rank(logger, logging.INFO, f"> total number of epochs: {num_epochs}") + + return document_index, sample_index, shuffle_index + log_single_rank( logger, logging.INFO, f"Load the {type(self).__name__} {self.index_split.name} indices" ) @@ -347,48 +519,38 @@ def _build_document_sample_shuffle_indices( log_single_rank( logger, logging.INFO, f"> total number of samples: {sample_index.shape[0] - 1}" ) - log_single_rank(logger, logging.INFO, f"> total number of epochs: {num_epochs}") return document_index, sample_index, shuffle_index + def _get_num_tokens_per_epoch(self) -> int: + """Calculate the number of tokens in a single epoch -def _get_num_tokens_per_epoch(indexed_dataset: MMapIndexedDataset, indices: numpy.ndarray) -> int: - """Calculate the number of tokens in a single epoch - - Args: - indexed_dataset (MMapIndexedDataset): The underlying MMapIndexedDataset - - indices (numpy.ndarray): The subset of indices into the underlying MMapIndexedDataset - - Returns: - int: The number of tokens in a single epoch - """ - return numpy.sum(indexed_dataset.sequence_lengths[indices]) - - -def _get_num_epochs(num_tokens_per_epoch: int, seq_length: int, num_samples: int) -> int: - """Calculate the number of epochs - - Args: - num_tokens_per_epoch (int): The number of tokens in a single epoch + Returns: + int: The number of tokens in a single epoch + """ + return int(numpy.sum(self.dataset.sequence_lengths[self.indices])) - seq_length (int): The sequence length in tokens + def _get_num_epochs(self, num_tokens_per_epoch: int) -> int: + 
"""Calculate the number of epochs - num_samples (int): The total number of samples + Args: + num_tokens_per_epoch (int): The number of tokens in a single epoch - Returns: - int: The number of epochs - """ - num_epochs = 0 - num_tokens = 0 - while True: - num_epochs += 1 - num_tokens += num_tokens_per_epoch - # -1 is because we need to retrieve seq_length + 1 token each time - # but the last token will overlap with the first token of the next - # sample except for the last sample. - if ((num_tokens - 1) // seq_length) >= num_samples: + Returns: + int: The number of epochs + """ + num_epochs = 1 + num_tokens = num_tokens_per_epoch + if self.num_samples is None: return num_epochs + else: + num_tokens_requested = ( + self.num_samples * self.config.sequence_length + ) + self.config.add_extra_token_to_sequence + while num_tokens < num_tokens_requested: + num_epochs += 1 + num_tokens += num_tokens_per_epoch + return num_epochs def _build_document_index( @@ -410,8 +572,6 @@ def _build_document_index( Returns: numpy.ndarray: The document index - - TODO: Explain separate_final_epoch """ if not separate_final_epoch or num_epochs == 1: document_index = numpy.mgrid[0:num_epochs, 0 : len(documents)][1] @@ -435,15 +595,12 @@ def _build_shuffle_index( num_samples (int): The size of the first shuffle range [0, num_samples) total_size (int): The size of the entire index. If larger than 'num_samples', it defines - - the second shuffle range [num_samples, total_size) + the second shuffle range [num_samples, total_size) numpy_random_state (numpy.random.RandomState): The NumPy random state Returns: numpy.ndarray: The shuffle index - - TODO: Explain [0, num_samples) [num_samples, total_size) split """ dtype_ = numpy.uint32 if total_size >= (numpy.iinfo(numpy.uint32).max - 1): @@ -458,3 +615,196 @@ def _build_shuffle_index( numpy_random_state.shuffle(shuffle_idx_last) return numpy.concatenate((shuffle_idx_first, shuffle_idx_last)) + + +def _get_ltor_masks_and_position_ids( + data: torch.Tensor, + eod_token: int, + reset_position_ids: bool, + reset_attention_mask: bool, + eod_mask_loss: bool, + create_attention_mask: bool, +): + """Build masks and position id for left to right model. + + Args: + data (torch.Tensor): The data tenor that holds the tokens from the dataset + + eod_token (int): ID of the token to that is considered the EOD + + reset_position_ids (bool): Switch to reset the document position ID's + + reset_attention_mask (bool): Switch to reset the attention mask + + eod_mask_loss (bool): Switch to enable the EOD mask loss + + create_attention_mask (bool): Switch to enable the attention masks generation. Can be + disabled if attention kernel generates masks by itself. + + Returns: + torch.Tensor: Attention mask needed to be used for Attention + + torch.Tensor: The mask used for loss value during training + + torch.Tensor: The position ID's of the token + """ + seq_length = data.numel() + + if create_attention_mask: + attention_mask = torch.tril( + torch.ones((seq_length, seq_length), device=data.device) + ).unsqueeze(0) + else: + attention_mask = None + + # Loss mask. + loss_mask = torch.ones(seq_length, dtype=torch.float, device=data.device) + if eod_mask_loss: + loss_mask[data == eod_token] = 0.0 + + # Position ids. + position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device) + # We need to clone as the ids will be modifed based on batch index. 
+ if reset_position_ids: + position_ids = position_ids.clone() + + if reset_position_ids or reset_attention_mask: + # Find indices where EOD token is. + eod_index = position_ids[data == eod_token] + # Detach indices from positions if going to modify positions. + if reset_position_ids: + eod_index = eod_index.clone() + + # Loop through EOD indices: + prev_index = 0 + for j in range(eod_index.numel()): + i = eod_index[j] + # Mask attention loss. + if reset_attention_mask and attention_mask is not None: + attention_mask[0, (i + 1) :, : (i + 1)] = 0 + # Reset positions. + if reset_position_ids: + position_ids[(i + 1) :] -= i + 1 - prev_index + prev_index = i + 1 + + if attention_mask is not None: + # Convert attention mask to binary: + attention_mask = attention_mask < 0.5 + + return attention_mask, loss_mask, position_ids + + +class MockGPTLowLevelDataset: + """The mock GPT low level dataset + + This class is meant to generate tokenized data in the classic "Megatron-LM" GPT style. Notably, + we add the end of document token to each element indexed in __getitem__ + + Args: + tokenizer (MegatronTokenizer): The tokenizer the special token information of which we use + to augment the mock data. + """ + + seed: int = 0 + """The hard-coded random seed to use to set the NumPy RNG""" + + size: int = 100000 + """The hard-coded number of samples to generate""" + + max_sequence_length: int = 4096 + """The hard-coded max sequence length to generate""" + + def __init__(self, tokenizer: MegatronTokenizer) -> None: + self.tokenizer = tokenizer + rng = numpy.random.default_rng(seed=self.seed) + self.sequence_lengths = rng.integers( + low=1, high=self.max_sequence_length, size=self.size, dtype=numpy.int32 + ) + + def __len__(self) -> int: + return self.size + + def __getitem__(self, idx: int) -> numpy.number: + length = self.sequence_lengths[idx] + sample = numpy.int64( + numpy.concatenate([numpy.arange(length - 1) + 1, [self.tokenizer.eod]]) + ) + return sample + + def get(self, idx: int, offset: int = 0, length: Optional[int] = None) -> numpy.ndarray: + """This function is n abstraction over __getitem__ with support for slicing + + Args: + idx (int): The index into the dataset + + offset (int): The integer token offset in the sequence + + length (Optional[int]): The number of tokens to grab from the sequence + + Returns: + numpy.ndarray: The sequence tokens at the index + """ + if length is None: + length = self.sequence_lengths[idx] - offset + return self[idx][offset : offset + length] + + +class MockGPTDataset(GPTDataset): + """The mock GPT dataset + + Args: + indexed_dataset (MockGPTLowLevelDataset): The MockGPTLowLevelDataset around which to build + the MockGPTDataset + + dataset_path (Optional[str]): This argument is of no consequence for the MockGPTDataset + + indices (numpy.ndarray): The set of the dataset indices to expose + + num_samples (int): The number of samples to draw from the dataset + + index_split (Split): The indices Split + + config (GPTDatasetConfig): The config + """ + + def __init__( + self, + dataset: MockGPTLowLevelDataset, + dataset_path: Optional[str], + indices: numpy.ndarray, + num_samples: int, + index_split: Split, + config: GPTDatasetConfig, + ) -> None: + assert config.mock + + super().__init__(dataset, dataset_path, indices, num_samples, index_split, config) + + @staticmethod + def numel_low_level_dataset(low_level_dataset: MockGPTLowLevelDataset) -> int: + """Abstract method implementation + + Args: + low_level_dataset (MockGPTLowLevelDataset): The underlying 
MockGPTLowLevelDataset + + Returns: + int: The number of unique elements in the underlying MockGPTLowLevelDataset + """ + return len(low_level_dataset) + + @staticmethod + def build_low_level_dataset( + dataset_path: Optional[str], config: GPTDatasetConfig + ) -> MockGPTLowLevelDataset: + """Abstract method implementation + + Args: + dataset_path (Optional[str]): This argument is of no consequence for the + MockGPTLowLevelDataset + + config (GPTDatasetConfig): The config + + Returns: + MockGPTLowLevelDataset: The underlying MockGPTLowLevelDataset + """ + return MockGPTLowLevelDataset(config.tokenizer) diff --git a/megatron/core/datasets/helpers.cpp b/megatron/core/datasets/helpers.cpp index 4e1b3dbc93..1a3e8448f3 100644 --- a/megatron/core/datasets/helpers.cpp +++ b/megatron/core/datasets/helpers.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -16,6 +17,61 @@ using namespace std; const int32_t LONG_SENTENCE_LEN = 512; + +void build_exhaustive_blending_indices(py::array_t &dataset_index, py::array_t &dataset_sample_index, const py::array_t &sizes, const int32_t num_datasets) { + /* + Build blending indices by sampling exactly as many samples from dataset[i] + as is requested by sizes[i] for all i in the range [0, num_datasets). + */ + auto dataset_index_ptr = dataset_index.mutable_unchecked<1>(); + auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>(); + auto sizes_ptr = sizes.unchecked<1>(); + + int64_t total_size = 0; + int64_t dataset_sample_counts[num_datasets]; + std::set dataset_unspent_indices; + for (int32_t i = 0; i < num_datasets; ++i) { + total_size += sizes_ptr[i]; + dataset_sample_counts[i] = 0; + dataset_unspent_indices.insert(i); + } + + // still need fractional weights to sample in proportion to sizes + double weights[num_datasets]; + for (int32_t i = 0; i < num_datasets; ++i) { + weights[i] = sizes_ptr[i] / static_cast(total_size); + } + + int64_t index_sample = 0; + while (dataset_unspent_indices.size() > 0) { + double index_sample_double = std::max(static_cast(index_sample), 1.0); + + int64_t error_argmax; + double error_max = std::numeric_limits::lowest(); + + for (int32_t index_dataset : dataset_unspent_indices) { + double error = weights[index_dataset] * index_sample_double - static_cast(dataset_sample_counts[index_dataset]); + if (error > error_max) { + error_argmax = index_dataset; + error_max = error; + } + } + + // Populate the indices. + dataset_index_ptr[index_sample] = static_cast(error_argmax); + dataset_sample_index_ptr[index_sample] = dataset_sample_counts[error_argmax]; + + // Update the total samples. + dataset_sample_counts[error_argmax] += 1; + + if (sizes_ptr[error_argmax] - static_cast(dataset_sample_counts[error_argmax]) == 0) { + dataset_unspent_indices.erase(error_argmax); + } + + index_sample += 1; + } +} + void build_blending_indices(py::array_t &dataset_index, py::array_t &dataset_sample_index, const py::array_t &weights, @@ -83,17 +139,22 @@ void build_blending_indices(py::array_t &dataset_index, } } -py::array build_sample_idx(const py::array_t &sizes_, - const py::array_t &doc_idx_, - const int32_t seq_length, - const int32_t num_epochs, - const int64_t tokens_per_epoch) -{ - /* Sample index (sample_idx) is used for gpt2 like dataset for which - the documents are flattened and the samples are built based on this - 1-D flatten array. 
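To clarify the selection rule in build_exhaustive_blending_indices above, here is an equivalent Python sketch; it is illustrative only, and tie-breaking between equally lagging datasets may differ marginally from the shipped C++.

    def exhaustive_blending_indices(sizes):
        # Greedy mirror of build_exhaustive_blending_indices: at each step pick the dataset
        # whose realized sample count lags its target proportion the most, until every
        # dataset has contributed exactly sizes[i] samples.
        total = sum(sizes)
        weights = [s / total for s in sizes]
        counts = [0] * len(sizes)
        unspent = set(range(len(sizes)))
        dataset_index, dataset_sample_index = [], []
        for n in range(total):
            step = float(max(n, 1))
            pick = max(unspent, key=lambda d: weights[d] * step - counts[d])
            dataset_index.append(pick)
            dataset_sample_index.append(counts[pick])
            counts[pick] += 1
            if counts[pick] == sizes[pick]:
                unspent.discard(pick)
        return dataset_index, dataset_sample_index

    # e.g. exhaustive_blending_indices([2, 1]) -> ([0, 1, 0], [0, 0, 1])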
It is a 2D array with sizes [number-of-samples + 1, 2] - where [..., 0] contains the index into `doc_idx` and [..., 1] is the - starting offset in that document.*/ +template +py::array_t build_sample_idx( + const py::array_t &sizes_, + const py::array_t &document_idx_, + const int32_t seq_length, + const int32_t num_epochs, + const int64_t tokens_per_epoch, + const bool drop_last_partial_sequence = true, + const int add_extra_token_to_sequence = 1 +){ + /* + Sample index (sample_idx) is used for gpt2 like dataset for which the documents are flattened + and the samples are built based on this 1-D flatten array. It is a 2D array with sizes + [number-of-samples + 1, 2] where [..., 0] contains the index into `doc_idx` and [..., 1] is + the starting offset in that document. + */ // Consistency checks. assert(seq_length > 1); @@ -102,68 +163,86 @@ py::array build_sample_idx(const py::array_t &sizes_, // Remove bound checks. auto sizes = sizes_.unchecked<1>(); - auto doc_idx = doc_idx_.unchecked<1>(); + auto document_idx = document_idx_.unchecked<1>(); - // Mapping and it's length (1D). - int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; - int32_t *sample_idx = new int32_t[2 * (num_samples + 1)]; + // Build the sample idx as a contiguous 1-D array of type T. + int64_t num_samples = 0; + if (drop_last_partial_sequence == true) { + num_samples = (num_epochs * tokens_per_epoch - add_extra_token_to_sequence) / seq_length; + } + else { + num_samples = ceil(float(num_epochs * tokens_per_epoch - add_extra_token_to_sequence) / seq_length); + } + T *sample_idx = new T[2 * (num_samples + 1)]; // Index into sample_idx. - int64_t sample_index = 0; - // Index into doc_idx. - int64_t doc_idx_index = 0; + int64_t sample_idx_index = 0; + // Index into document_idx. + T document_idx_index = 0; // Begining offset for each document. - int32_t doc_offset = 0; + T doc_offset = 0; // Start with first document and no offset. - sample_idx[2 * sample_index] = doc_idx_index; - sample_idx[2 * sample_index + 1] = doc_offset; - ++sample_index; + sample_idx[2 * sample_idx_index] = document_idx_index; + sample_idx[2 * sample_idx_index + 1] = doc_offset; + ++sample_idx_index; - while (sample_index <= num_samples) + while (sample_idx_index <= num_samples) { // Start with a fresh sequence. - int32_t remaining_seq_length = seq_length + 1; + int32_t remaining_seq_length = seq_length + add_extra_token_to_sequence; while (remaining_seq_length != 0) { // Get the document length. - auto doc_id = doc_idx[doc_idx_index]; - auto doc_length = sizes[doc_id] - doc_offset; + auto document_index = document_idx[document_idx_index]; + auto document_length = sizes[document_index] - doc_offset; // And add it to the current sequence. - remaining_seq_length -= doc_length; + remaining_seq_length -= document_length; // If we have more than a full sequence, adjust offset and set // remaining length to zero so we return from the while loop. // Note that -1 here is for the same reason we have -1 in // `_num_epochs` calculations. if (remaining_seq_length <= 0) { - doc_offset += (remaining_seq_length + doc_length - 1); + doc_offset += (remaining_seq_length + document_length - add_extra_token_to_sequence); remaining_seq_length = 0; } else { // Otherwise, start from the begining of the next document. - ++doc_idx_index; + if (document_idx_index == (document_idx_.shape(0) - 1)) + { + // If we have reached the end of the documents, break. 
+ assert(sample_idx_index == num_samples); + doc_offset = sizes[document_idx[document_idx_index]] - add_extra_token_to_sequence; + break; + } + ++document_idx_index; doc_offset = 0; } } // Record the sequence. - sample_idx[2 * sample_index] = doc_idx_index; - sample_idx[2 * sample_index + 1] = doc_offset; - ++sample_index; + sample_idx[2 * sample_idx_index] = document_idx_index; + sample_idx[2 * sample_idx_index + 1] = doc_offset; + ++sample_idx_index; } // Method to deallocate memory. - py::capsule free_when_done(sample_idx, [](void *mem_) - { - int32_t *mem = reinterpret_cast(mem_); - delete[] mem; }); + py::capsule free_when_done( + sample_idx, + [](void *mem_){ + T *mem = reinterpret_cast(mem_); + delete[] mem; + } + ); // Return the numpy array. - const auto byte_size = sizeof(int32_t); - return py::array(std::vector{num_samples + 1, 2}, // shape - {2 * byte_size, byte_size}, // C-style contiguous strides - sample_idx, // the data pointer - free_when_done); // numpy array references + const auto byte_size = sizeof(T); + return py::array_t( + std::vector{num_samples + 1, 2}, // shape + {2 * byte_size, byte_size}, // C-style contiguous strides + sample_idx, // the data pointer + free_when_done // numpy array references + ); } inline int32_t get_target_sample_len(const int32_t short_seq_ratio, @@ -756,10 +835,12 @@ py::array build_blocks_mapping(const py::array_t &docs_, } } -PYBIND11_MODULE(helpers, m) +PYBIND11_MODULE(helpers_cpp, m) { m.def("build_mapping", &build_mapping); m.def("build_blocks_mapping", &build_blocks_mapping); - m.def("build_sample_idx", &build_sample_idx); + m.def("build_sample_idx_int32", &build_sample_idx); + m.def("build_sample_idx_int64", &build_sample_idx); m.def("build_blending_indices", &build_blending_indices); + m.def("build_exhaustive_blending_indices", &build_exhaustive_blending_indices); } diff --git a/megatron/core/datasets/helpers.py b/megatron/core/datasets/helpers.py new file mode 100644 index 0000000000..9978a6050a --- /dev/null +++ b/megatron/core/datasets/helpers.py @@ -0,0 +1,64 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +import numpy + +# Implicit imports for backwards compatibility +# Explicit imports for readability +from megatron.core.datasets.helpers_cpp import * +from megatron.core.datasets.helpers_cpp import build_sample_idx_int32, build_sample_idx_int64 + + +def build_sample_idx( + sizes: numpy.ndarray, + document_indices: numpy.ndarray, + sequence_length: int, + num_epochs: int, + tokens_per_epoch: int, + drop_last_partial_sequence: bool = True, + add_extra_token_to_sequence: bool = True, +): + """Build the 2-D sample index using the properly typed templated C++ function from helpers.cpp + + Args: + sizes (numpy.ndarray): The 1-D array of document lengths + + document_indices (numpy.ndarray): The 1-D array of document indices + + sequence_length (int): The sequence length + + num_epochs (int): The number of epochs + + tokens_per_epoch (int): The number of tokens per epoch + + drop_last_partial_sequence (bool): Whether to omit the last partial sequence in the sample + index should it exist. Defaults to True. + + add_extra_token_to_sequence (bool): Whether to build samples with sequence length + `sequence_length + 1`. Defaults to True. 
+ + Returns: + numpy.ndarray: The 2-D sample index + """ + sample_idx_max = max(document_indices.shape[0], sizes.max()) + if sample_idx_max <= numpy.iinfo(numpy.int32).max: + sample_idx = build_sample_idx_int32( + sizes, + document_indices, + sequence_length, + num_epochs, + tokens_per_epoch, + drop_last_partial_sequence, + 1 if add_extra_token_to_sequence else 0, + ) + assert sample_idx.min() >= 0 and sample_idx.max() <= sample_idx_max + else: + sample_idx = build_sample_idx_int64( + sizes, + document_indices, + sequence_length, + num_epochs, + tokens_per_epoch, + drop_last_partial_sequence, + 1 if add_extra_token_to_sequence else 0, + ) + return sample_idx diff --git a/megatron/core/datasets/indexed_dataset.py b/megatron/core/datasets/indexed_dataset.py index cd62160cea..29975336f1 100644 --- a/megatron/core/datasets/indexed_dataset.py +++ b/megatron/core/datasets/indexed_dataset.py @@ -10,16 +10,28 @@ import shutil import struct import time +from abc import ABC, abstractmethod from enum import Enum from functools import lru_cache from itertools import accumulate from types import TracebackType from typing import List, Optional, Tuple, Type, Union +try: + import boto3 +except ModuleNotFoundError: + pass import numpy import torch -from megatron.core.datasets.utils import log_single_rank +from megatron.core.datasets.utils_s3 import ( + S3Config, + is_s3_path, + maybe_download_file, + object_exists, + parse_s3_path, +) +from megatron.core.utils import log_single_rank logger = logging.getLogger(__name__) @@ -27,8 +39,7 @@ class DType(Enum): - """The NumPy data type Enum for writing/reading the MMapIndexedDataset indices - """ + """The NumPy data type Enum for writing/reading the IndexedDataset indices""" uint8 = 1 int8 = 2 @@ -300,10 +311,10 @@ def __init__(self, idx_path: str, multimodal: bool) -> None: ) def __del__(self) -> None: - """Clean up the object - """ - self.bin_buffer_mmap._mmap.close() - del self.bin_buffer_mmap + """Clean up the object""" + if hasattr(self, "bin_buffer_mmap"): + self.bin_buffer_mmap._mmap.close() + del self.bin_buffer_mmap def __len__(self) -> int: """Return the length of the dataset @@ -321,8 +332,7 @@ def __getitem__(self, idx: int) -> Tuple[numpy.int32, numpy.int64, Optional[nump idx (int): The index into the dataset Returns: - Tuple[numpy.int32, numpy.int64, Optional[numpy.int8]]: The pointer, length and mode at - the index + Tuple[numpy.int32, numpy.int64, Optional[numpy.int8]]: The pointer, length and mode at the index """ return ( self.sequence_pointers[idx], @@ -331,66 +341,264 @@ def __getitem__(self, idx: int) -> Tuple[numpy.int32, numpy.int64, Optional[nump ) -class MMapIndexedDataset(torch.utils.data.Dataset): +class _BinReader(ABC): + """Abstract class to read the data (.bin) file""" + + @abstractmethod + def read(self, dtype: Type[numpy.number], count: int, offset: int) -> numpy.ndarray: + """Read bytes into a numpy array. + + Args: + dtype (Type[numpy.number]): Data-type of the returned array. + + count (int): Number of items to read. + + offset (int): Start reading from this offset (in bytes). + + Returns: + numpy.ndarray: An array with `count` items and data-type `dtype` constructed from reading bytes from the data file starting at `offset`. + """ + pass + + +class _MMapBinReader(_BinReader): + """A _BinReader that memory maps the data (.bin) file + + Args: + bin_path (str): bin_path (str): The path to the data (.bin) file. 
+ """ + + def __init__(self, bin_path: str) -> None: + self._bin_buffer_mmap = numpy.memmap(bin_path, mode="r", order="C") + self._bin_buffer = memoryview(self._bin_buffer_mmap) + + def read(self, dtype: Type[numpy.number], count: int, offset: int) -> numpy.ndarray: + """Read bytes into a numpy array. + + Args: + dtype (Type[numpy.number]): Data-type of the returned array. + + count (int): Number of items to read. + + offset (int): Start reading from this offset (in bytes). + + Returns: + numpy.ndarray: An array with `count` items and data-type `dtype` constructed from reading bytes from the data file starting at `offset`. + """ + return numpy.frombuffer(self._bin_buffer, dtype=dtype, count=count, offset=offset) + + def __del__(self) -> None: + """Clean up the object.""" + if self._bin_buffer_mmap is not None: + self._bin_buffer_mmap._mmap.close() + del self._bin_buffer_mmap + + +class _FileBinReader(_BinReader): + """A _BinReader that reads from the data (.bin) file using a file pointer + + Args: + bin_path (str): bin_path (str): The path to the data (.bin) file. + """ + + def __init__(self, bin_path: str) -> None: + self._bin_path = bin_path + + def read(self, dtype: Type[numpy.number], count: int, offset: int) -> numpy.ndarray: + """Read bytes into a numpy array. + + Args: + dtype (Type[numpy.number]): Data-type of the returned array. + + count (int): Number of items to read. + + offset (int): Start reading from this offset (in bytes). + + Returns: + numpy.ndarray: An array with `count` items and data-type `dtype` constructed from reading bytes from the data file starting at `offset`. + """ + sequence = numpy.empty(count, dtype=dtype) + with open(self._bin_path, mode='rb', buffering=0) as bin_buffer_file: + bin_buffer_file.seek(offset) + bin_buffer_file.readinto(sequence) + return sequence + + +class _S3BinReader(_BinReader): + """A _BinReader that reads from the data (.bin) file from S3 + + Args: + bin_path (str): bin_path (str): The path to the data (.bin) file. + + bin_chunk_nbytes (int, optional): If not None, then maintain an in-memory cache to speed up calls to the `read` method. Furthermore, on a cache miss, download this number of bytes to refresh the cache. Otherwise (None), do not maintain an in-memory cache. A class that inherits from _BinReader may not implement caching in which case it should assert that `bin_chunk_nbytes` is None at initialization. + """ + + def __init__(self, bin_path: str, bin_chunk_nbytes: int) -> None: + assert bin_chunk_nbytes > 0 + self._client = boto3.client("s3") + self._s3_bucket, self._s3_key = parse_s3_path(bin_path) + self._cache = None + self._cache_bytes_start = None + self._cache_bytes_end = None + self._cache_nbytes = bin_chunk_nbytes + + def _extract_from_cache(self, offset: int, size: int) -> bytes: + """Extract `size` bytes starting at `offset` bytes into the cache""" + start = offset - self._cache_bytes_start + assert start >= 0 + end = start + size + assert end <= len(self._cache) + return self._cache[start:end] + + def read(self, dtype: Type[numpy.number], count: int, offset: int) -> numpy.ndarray: + """Read bytes into a numpy array. + + Let `size` be the `count` * `DType.size(dtype)`. If the requested span of bytes [`offset`, + `offset` + `size`) is covered by the in-memory cache maintained by this class, then this + function extracts the requested span from that cache and returns it. Otherwise, this + function first refreshes the cache and then extracts the requested span from the refreshed + cache and returns it. 
+ + The cache is refreshed based on `offset` and `size`. In particular, we divide all the bytes + in an S3 object into blocks, where each block contains `bin_chunk_nbytes` bytes. We assign + each block an index starting from 0. We take the block with index (`offset` // + `bin_chunk_nbytes`) to refresh the cache. If this new block still does not cover the + requested span, we extend it just enough to include `offset` + `size`. + + Args: + dtype (Type[numpy.number]): Data-type of the returned array. + + count (int): Number of items to read. + + offset (int): Start reading from this offset (in bytes). + + Returns: + numpy.ndarray: An array with `count` items and data-type `dtype` constructed from reading bytes from the data file starting at `offset`. + """ + size = count * DType.size(dtype) + if ( + self._cache is not None + and offset >= self._cache_bytes_start + and offset + size <= self._cache_bytes_end + ): + return numpy.frombuffer(self._extract_from_cache(offset, size), dtype=dtype) + + bytes_start = (offset // self._cache_nbytes) * self._cache_nbytes + assert bytes_start >= 0 + assert offset >= bytes_start + bytes_end = max(bytes_start + self._cache_nbytes, offset + size) + assert bytes_end >= 1 + self._cache = self._client.get_object( + Bucket=self._s3_bucket, + Key=self._s3_key, + # Subtract 1, because the end of Range is inclusive. + Range=f'bytes={bytes_start}-{bytes_end-1}', + )['Body'].read() + self._cache_bytes_start = bytes_start + self._cache_bytes_end = bytes_end + return numpy.frombuffer(self._extract_from_cache(offset, size), dtype=dtype) + + def __del__(self) -> None: + """Clean up the object""" + self._client.close() + + +class IndexedDataset(torch.utils.data.Dataset): """The low-level interface dataset class Args: path_prefix (str): The index (.idx) and data (.bin) prefix - multimodal (bool, optional): Whether the dataset is multimodal. Defaults to False. + multimodal (bool): Whether the dataset is multimodal. Defaults to False. + + mmap (bool): Whether to mmap the .bin files. Defaults to True. + + s3_config (Optional[S3Config]): Supplied only for data stored on S3. IndexedDataset downloads the index (.idx) file to `s3_config.path_to_idx_cache` and streams data from the data (.bin) file in `s3_config.bin_chunk_nbytes` blocks. Note that `mmap` must be disabled for S3 data loading. Defaults to None. 
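As a usage sketch of the reader selection described above (the dataset prefixes, bucket, and cache paths are placeholders, and S3 access assumes configured boto3 credentials):

    from megatron.core.datasets.indexed_dataset import IndexedDataset
    from megatron.core.datasets.utils_s3 import S3Config

    # Local data, memory-mapped (the default behavior).
    local_mmap = IndexedDataset("path/to/dataset_prefix")

    # Local data read through a plain file pointer instead of mmap.
    local_file = IndexedDataset("path/to/dataset_prefix", mmap=False)

    # S3-resident data: mmap must be disabled; the .idx file is cached locally under
    # path_to_idx_cache and the .bin file is streamed in fixed-size byte ranges.
    s3_backed = IndexedDataset(
        "s3://bucket/dataset_prefix",
        mmap=False,
        s3_config=S3Config(path_to_idx_cache="/path/to/idx/cache"),
    )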
""" - def __init__(self, path_prefix: str, multimodal: bool = False) -> None: + def __init__( + self, + path_prefix: str, + multimodal: bool = False, + mmap: bool = True, + s3_config: Optional[S3Config] = None, + ) -> None: super().__init__() self.path_prefix = None self.multimodal = None + self.mmap = None + self.s3_config = None self.index = None - self.bin_buffer = None - self.bin_buffer_mmap = None + self.bin_reader = None + + if is_s3_path(path_prefix) and s3_config is not None: + idx_path = get_idx_path(path_prefix) + cache_idx_path = os.path.join(s3_config.path_to_idx_cache, os.path.basename(idx_path)) + maybe_download_file(idx_path, cache_idx_path) - self.initialize(path_prefix, multimodal) + self.initialize(path_prefix, multimodal, mmap, s3_config) - def initialize(self, path_prefix: str, multimodal: bool) -> None: + def initialize( + self, path_prefix: str, multimodal: bool, mmap: bool, s3_config: Optional[S3Config] + ) -> None: """Initialize the dataset - This method is called by MMapIndexedDataset.__init__ during object creation and by - MMapIndexedDataset.__setstate__ during un-puckling + This method is called by IndexedDataset.__init__ during object creation and by + IndexedDataset.__setstate__ during un-pickling Args: path_prefix (str): The index (.idx) and data (.bin) prefix multimodal (bool): Whether the dataset is multimodal + + mmap (bool): Whether to mmap the .bin file + + s3_config (Optional[S3Config]): See IndexedDataset docstring for details. """ + idx_path = get_idx_path(path_prefix) + bin_path = get_bin_path(path_prefix) + if s3_config is None: + assert os.path.exists(idx_path) and os.path.exists( + bin_path + ), f"One or both of the .idx and .bin files cannot be found at the path prefix {path_prefix}" self.path_prefix = path_prefix self.multimodal = multimodal - self.index = _IndexReader(get_idx_path(self.path_prefix), self.multimodal) - self.bin_buffer_mmap = numpy.memmap(get_bin_path(self.path_prefix), mode="r", order="C") - self.bin_buffer = memoryview(self.bin_buffer_mmap) + self.mmap = mmap + self.s3_config = s3_config + if mmap: + assert not s3_config + self.bin_reader = _MMapBinReader(bin_path) + elif s3_config: + assert not mmap + self.bin_reader = _S3BinReader(bin_path, s3_config.bin_chunk_nbytes) + idx_path = os.path.join( + s3_config.path_to_idx_cache, os.path.basename(get_idx_path(path_prefix)) + ) + else: + self.bin_reader = _FileBinReader(bin_path) + self.index = _IndexReader(idx_path, self.multimodal) - def __getstate__(self) -> Tuple[str, bool]: + def __getstate__(self) -> Tuple[str, bool, bool, Optional[S3Config]]: """Get the state during pickling Returns: - Tuple[str, bool]: The state tuple + Tuple[str, bool, bool, Optional[S3Config]]: The state tuple """ - return self.path_prefix, self.multimodal + return self.path_prefix, self.multimodal, self.mmap, self.s3_config - def __setstate__(self, state: Tuple[str, bool]) -> None: + def __setstate__(self, state: Tuple[str, bool, bool, Optional[S3Config]]) -> None: """Set the state during un-pickling Args: - state (Tuple[str, bool]): The state tuple + state (Tuple[str, bool, bool, Optional[S3Config]]): The state tuple """ - path_prefix, multimodal = state - self.initialize(path_prefix, multimodal) + path_prefix, multimodal, mmap, s3_config = state + self.initialize(path_prefix, multimodal, mmap, s3_config) def __del__(self) -> None: - """Clean up the object - """ - if self.bin_buffer_mmap is not None: - self.bin_buffer_mmap._mmap.close() - del self.bin_buffer_mmap + """Clean up the object""" + del 
self.bin_reader del self.index def __len__(self) -> int: @@ -415,16 +623,12 @@ def __getitem__( TypeError: When the index is of an unexpected type Returns: - Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]: The sequence tokens and - modes at the index or index slice + Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]: The sequence tokens and modes at the index or index slice """ if isinstance(idx, (int, numpy.integer)): sequence_pointer, sequence_length, sequence_mode = self.index[idx] - sequence = numpy.frombuffer( - self.bin_buffer, - dtype=self.index.dtype, - count=sequence_length, - offset=sequence_pointer, + sequence = self.bin_reader.read( + dtype=self.index.dtype, count=sequence_length, offset=sequence_pointer ) return (sequence, sequence_mode) if sequence_mode is not None else sequence elif isinstance(idx, slice): @@ -435,8 +639,7 @@ def __getitem__( sequence_modes = self.index.sequence_modes[idx] if self.multimodal else None sequence_offsets = list(accumulate(sequence_lengths)) sequences = numpy.split( - numpy.frombuffer( - self.bin_buffer, + self.bin_reader.read( dtype=self.index.dtype, count=sum(sequence_lengths), offset=self.index.sequence_pointers[start], @@ -452,13 +655,23 @@ def get(self, idx: int, offset: int = 0, length: Optional[int] = None) -> numpy. return a portion of the item. get(idx) is the same as [idx] but get() does not support slicing. + + Args: + idx (Union[int, numpy.integer]): The index into the dataset + + offset (int): The integer token offset in the sequence + + length (int): The number of tokens to grab from the sequence + + Returns: + Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]: The sequence tokens and modes at the index """ sequence_pointer, sequence_length, sequence_mode = self.index[idx] if length is None: length = sequence_length - offset sequence_pointer += offset * DType.size(self.index.dtype) - sequence = numpy.frombuffer( - self.bin_buffer, dtype=self.index.dtype, count=length, offset=sequence_pointer + sequence = self.bin_reader.read( + dtype=self.index.dtype, count=length, offset=sequence_pointer ) return (sequence, sequence_mode) if sequence_mode is not None else sequence @@ -511,21 +724,26 @@ def sequence_modes(self) -> numpy.ndarray: @staticmethod def exists(path_prefix: str) -> bool: - """Return whether the MMapIndexedDataset exists on disk at the prefix + """Return whether the IndexedDataset exists on disk at the prefix Args: path_prefix (str): The prefix to the index (.idx) and data (.bin) files Returns: - bool: Whether the MMapIndexedDataset exists on disk at the prefix + bool: Whether the IndexedDataset exists on disk at the prefix """ + if is_s3_path(path_prefix): + s3_client = boto3.client("s3") + return object_exists(s3_client, get_idx_path(path_prefix)) and object_exists( + s3_client, get_bin_path(path_prefix) + ) return os.path.exists(get_idx_path(path_prefix)) and os.path.exists( get_bin_path(path_prefix) ) -class MMapIndexedDatasetBuilder(object): - """Builder class for the MMapIndexedDataset class +class IndexedDatasetBuilder(object): + """Builder class for the IndexedDataset class Args: bin_path (str): The path to the data (.bin) file @@ -567,9 +785,10 @@ def add_document( Args: tensor (torch.Tensor): The document to add + lengths (List[int]): The lengths of each item in the document - modes (Optional[List[int]], optional): The modes for each item in the document. - Defaults to None. + + modes (Optional[List[int]], optional): The modes for each item in the document. Defaults to None. 
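As a usage note for the reader interface above (`exists`, indexing, and `get` with an offset), the following hedged sketch shows both the local mmap path and the S3 path; the path prefixes, the `S3Config` constructor fields, and the import location of `S3Config` are assumptions made for illustration.

```python
# Hedged usage sketch; paths and S3Config fields are illustrative assumptions.
from megatron.core.datasets.indexed_dataset import IndexedDataset, S3Config

prefix = "/data/my-corpus_text_document"          # hypothetical local .idx/.bin prefix
if IndexedDataset.exists(prefix):
    ds = IndexedDataset(prefix)                   # mmap=True by default
    full_sequence = ds[0]                         # all tokens of the first sequence
    middle = ds.get(0, offset=2, length=8)        # 8 tokens starting at token index 2

# For S3-hosted data, mmap must be disabled and an S3Config supplied.
s3_ds = IndexedDataset(
    "s3://my-bucket/my-corpus_text_document",     # hypothetical S3 prefix
    mmap=False,
    s3_config=S3Config(path_to_idx_cache="/tmp/idx_cache"),
)
```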
""" np_array = numpy.array(tensor, dtype=self.dtype) self.data_file.write(np_array.tobytes(order="C")) @@ -579,12 +798,11 @@ def add_document( self.sequence_modes.extend(modes if modes is not None else [0] * lengths) def end_document(self) -> None: - """Finalize the document, for use with MMapIndexedDatasetBuilder.add_item - """ + """Finalize the document, for use with IndexedDatasetBuilder.add_item""" self.document_indices.append(len(self.sequence_lengths)) def add_index(self, path_prefix: str) -> None: - """Add an entire MMapIndexedDataset to the dataset + """Add an entire IndexedDataset to the dataset Args: path_prefix (str): The index (.idx) and data (.bin) prefix diff --git a/megatron/core/datasets/masked_dataset.py b/megatron/core/datasets/masked_dataset.py new file mode 100644 index 0000000000..c2a02ebaea --- /dev/null +++ b/megatron/core/datasets/masked_dataset.py @@ -0,0 +1,425 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +import logging +import os +import time +from abc import abstractmethod +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import numpy +import torch + +from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig +from megatron.core.datasets.indexed_dataset import IndexedDataset +from megatron.core.datasets.megatron_dataset import MegatronDataset +from megatron.core.datasets.utils import Split +from megatron.core.utils import log_single_rank + +logger = logging.getLogger(__name__) + + +@dataclass +class MaskedWordPieceDatasetConfig(BlendedMegatronDatasetConfig): + """Configuration object for Megatron Core Masked WordPiece datasets""" + + masking_probability: float = None + """The probability we mask a candidate N-gram""" + + short_sequence_probability: float = None + """The probability we return a sequence shorter than the target sequence length""" + + masking_max_ngram: int = None + """The maximum length N-gram to consider masking or permuting""" + + masking_do_full_word: bool = None + """Whether we mask the whole word or its component parts""" + + masking_do_permutation: bool = None + """Whether we shuffle a subset of candidate N-grams in addition""" + + masking_use_longer_ngrams: bool = None + """Whether to favor longer N-grams over shorter N-grams""" + + masking_use_geometric_distribution: bool = None + """Whether to draw the size of the N-gram from a geometric distribution according to SpanBERT + https://arxiv.org/abs/1907.10529 (Section 3.1) + """ + + def __post_init__(self) -> None: + """Do asserts and set fields post init""" + super().__post_init__() + + assert self.tokenizer is not None + + assert self.masking_probability is not None + assert self.short_sequence_probability is not None + assert self.masking_max_ngram is not None + assert self.masking_do_full_word is not None + assert self.masking_do_permutation is not None + assert self.masking_use_longer_ngrams is not None + assert self.masking_use_geometric_distribution is not None + + assert self.masking_probability > 0 and self.masking_probability < 1.0 + assert self.short_sequence_probability >= 0 and self.short_sequence_probability <= 1.0 + assert self.masking_max_ngram > 0 + assert not (self.masking_use_geometric_distribution and self.masking_do_permutation) + + if self.masking_use_geometric_distribution and self.masking_use_longer_ngrams: + log_single_rank( + logger, + logging.WARNING, + "The use of a geometric distribution overrides the default distribution", + ) + + +class 
MaskedWordPieceDataset(MegatronDataset): + """The semi-abstract base class for masked WordPiece datasets + + This implementation makes the rigid assumption that all inheritor datasets are built upon the + IndexedDataset class. This assumption may be pushed down to the inheritors in future if + necessary. + + NB: WordPiece tokenization prepends a double hash "##" to all tokens/pieces in a word, save the + first token/piece. + + Args: + indexed_dataset (IndexedDataset): The IndexedDataset around which to build the + MegatronDataset + + dataset_path (str): The real path on disk to the dataset, for bookkeeping + + indexed_indices (numpy.ndarray): The set of the documents indices to expose + + num_samples (Optional[int]): The number of samples to draw from the indexed dataset. + When None, build as many samples as correspond to one epoch. + + index_split (Split): The indexed_indices Split + + config (MaskedWordPieceDatasetConfig): The config + """ + + def __init__( + self, + indexed_dataset: IndexedDataset, + dataset_path: str, + indexed_indices: numpy.ndarray, + num_samples: Optional[int], + index_split: Split, + config: MaskedWordPieceDatasetConfig, + ) -> None: + super().__init__( + indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config + ) + + @staticmethod + def numel_low_level_dataset(low_level_dataset: IndexedDataset) -> int: + return low_level_dataset.document_indices.shape[0] - 1 + + @staticmethod + def build_low_level_dataset( + dataset_path: str, config: MaskedWordPieceDatasetConfig + ) -> IndexedDataset: + return IndexedDataset(dataset_path) + + @staticmethod + def _key_config_attributes() -> List[str]: + """Inherited method implementation + + Returns: + List[str]: The key config attributes + """ + return super(MaskedWordPieceDataset, MaskedWordPieceDataset)._key_config_attributes() + [ + "masking_probability", + "short_sequence_probability", + "masking_max_ngram", + "masking_do_full_word", + "masking_do_permutation", + "masking_use_longer_ngrams", + "masking_use_geometric_distribution", + ] + + def __len__(self) -> int: + return self.sample_index.shape[0] + + def _build_sample_index( + self, sequence_length: int, min_sentences_per_sample: int + ) -> numpy.ndarray: + path_to_cache = self.config.path_to_cache + if path_to_cache is None: + path_to_cache = os.path.join( + self.dataset.path_prefix, "cache", f"{type(self).__name__}_indices" + ) + + get_path_to = lambda suffix: os.path.join( + path_to_cache, f"{self.unique_description_hash}-{type(self).__name__}-{suffix}" + ) + path_to_description = get_path_to("description.txt") + path_to_sample_index = get_path_to("sample_index.npy") + cache_hit = all(map(os.path.isfile, [path_to_description, path_to_sample_index])) + + if self.num_samples is not None: + num_epochs = numpy.iinfo(numpy.int32).max - 1 + else: + num_epochs = 1 + + if not cache_hit and torch.distributed.get_rank() == 0: + log_single_rank( + logger, + logging.INFO, + f"Build and save the {type(self).__name__} {self.index_split.name} indices", + ) + self.built_anew_on_cache_miss = True + + os.makedirs(path_to_cache, exist_ok=True) + + # Write the description + with open(path_to_description, "wt") as writer: + writer.write(self.unique_description) + + # Build the sample index + log_single_rank( + logger, + logging.INFO, + f"\tBuild and save the sample index to {os.path.basename(path_to_sample_index)}", + ) + t_beg = time.time() + from megatron.core.datasets import helpers + + # Add +1 for access to document upper bound + indices = 
numpy.append(self.indices, self.indices[-1] + 1) + + sample_index = helpers.build_mapping( + self.dataset.document_indices[indices], + self.dataset.sequence_lengths, + num_epochs, + self.num_samples, + sequence_length, + self.config.short_sequence_probability, + self.config.random_seed, + False, + min_sentences_per_sample, + ) + numpy.save(path_to_sample_index, sample_index, allow_pickle=True) + t_end = time.time() + log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") + + log_single_rank( + logger, logging.INFO, f"> total number of samples: {sample_index.shape[0]}" + ) + log_single_rank(logger, logging.INFO, f"> total number of epochs: {num_epochs}") + + return sample_index + + log_single_rank( + logger, logging.INFO, f"Load the {type(self).__name__} {self.index_split.name} indices" + ) + + log_single_rank( + logger, + logging.INFO, + f"\tLoad the sample index from {os.path.basename(path_to_sample_index)}", + ) + t_beg = time.time() + sample_index = numpy.load(path_to_sample_index, allow_pickle=True, mmap_mode="r") + t_end = time.time() + log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") + + return sample_index + + def _create_masked_lm_predictions( + self, + token_ids: List[int], + target_sequence_length: int, + numpy_random_state: numpy.random.RandomState, + ) -> Tuple[List[int], List[int], List[int], List[int], List[Tuple[List[int], List[int]]]]: + """Creates the predictions for the masked LM objective + + Args: + token_ids (List[int]): The token ids + target_sequence_length (int): The target sequence length + numpy_random_state (numpy.random.RandomState): The NumPy random state + + Returns: + Tuple[List[int], List[int], List[int], List[int], List[Tuple[List[int], List[int]]]]: + 1. masked_token_ids -> The masked sequence + 2. masked_positions -> The indices for the masked token ids + 3. masked_labels -> The original token ids for the masked token ids + 4. boundaries -> The sentence and word boundaries for the sequence + 4. masked_spans -> The masked positions and labels with N-gram info intact + """ + # Build the token sentence and word boundaries and the masking candidates + # e.g. [cls, id, ##id, ##id, id, ##id, sep, id, ##id, sep] + # -> boundaries: [1, 1, 0, 0, 1, 0, 1, 1, 0, 1] + # -> candidates with whole word masking: [[1, 2, 3], [4, 5], [7, 8]] + # -> candidates sans whole word masking: [[1], [2], [3], [4], [5], [7], [8]] + boundaries = [] + candidates = [] + for i, token_id in enumerate(token_ids): + if token_id == self.config.tokenizer.cls or token_id == self.config.tokenizer.sep: + boundaries.append(1) + else: + if not self.config.tokenizer.inv_vocab[token_id].startswith("##"): + boundaries.append(1) + candidates.append([i]) + else: + boundaries.append(0) + if self.config.masking_do_full_word and len(candidates) > 0: + candidates[-1].append(i) + else: + candidates.append([i]) + + n_maskings = min( + self.config.masking_probability * target_sequence_length, + max(1, int(round(len(token_ids) * self.config.masking_probability))), + ) + + ngram_nvals = numpy.arange(self.config.masking_max_ngram, dtype=numpy.int64) + 1 + + # By default, the N-gram probabilities are inversely proportional to N + # e.g. 
N = 3 + # -> P = array([0.54545455, 0.27272727, 0.18181818]) + nprobs = 1.0 / ngram_nvals + nprobs = nprobs / nprobs.sum(keepdims=True) + if self.config.masking_use_longer_ngrams: + nprobs = nprobs[::-1] + + # Create a nested list of depth 3 + # layer 1: the candidate dimension + # layer 2: the N-gram dimension + # layer 3: the token dimension + candidate_ngrams = [ + [candidates[idx : idx + n] for n in ngram_nvals] for idx in range(len(candidates)) + ] + numpy_random_state.shuffle(candidate_ngrams) + + masked_token_ids = list(token_ids) + masked_positions_and_labels = [] + masked_spans = [] + masked_indices = set() + for candidate_idx in range(len(candidate_ngrams)): + n_ngrams = len(candidate_ngrams[candidate_idx]) + + # Stop when we hit our desired number of maskings + if len(masked_positions_and_labels) >= n_maskings: + break + + # Do nothing for candidates with no ngrams + if not candidate_ngrams[candidate_idx]: + continue + + # Choose the initial value of N + if self.config.masking_use_geometric_distribution: + # Sample N from a geometric distribution with p = 0.2 and clip + # i.e. SpanBERT + # -> https://arxiv.org/abs/1907.10529 (Section 3.1) + p = 0.2 + n = min(numpy_random_state.geometric(p), self.config.masking_max_ngram) + else: + p = nprobs[:n_ngrams] / nprobs[:n_ngrams].sum(keepdims=True) + n = numpy_random_state.choice(ngram_nvals[:n_ngrams], p=p) + + while True: + ngram_indices = sum(candidate_ngrams[candidate_idx][n - 1], []) + n = n - 1 + # Success: masking this N-gram puts us below the desired number of maskings + if n_maskings >= len(masked_positions_and_labels) + len(ngram_indices): + skip_candidate = False + break + # Failure: no N-grams remain for this candidate + if n == 0: + skip_candidate = True + break + + # Do nothing for candidates whose 1-gram is too long + if skip_candidate: + continue + + # Do nothing for candidate indices which have already been masked + if any(map(lambda idx: idx in masked_indices, ngram_indices)): + continue + + # Mask the tokens and record their original positions and values + for index in ngram_indices: + masked_indices.add(index) + mask = self._get_token_mask(numpy_random_state) + if mask is None: + masked_token_ids[index] = token_ids[index] + else: + masked_token_ids[index] = mask + masked_positions_and_labels.append((index, token_ids[index])) + + masked_spans.append((ngram_indices, [token_ids[index] for index in ngram_indices])) + + assert len(masked_positions_and_labels) <= n_maskings + + numpy_random_state.shuffle(candidate_ngrams) + + if self.config.masking_do_permutation: + + n_swappings = n_maskings + + permuted_indices = set() + for candidate_idx in range(len(candidate_ngrams)): + n_ngrams = len(candidate_ngrams[candidate_idx]) + + if len(permuted_indices) >= n_swappings: + break + + # Do nothing for candidates with no ngrams + if not candidate_ngrams[candidate_idx]: + continue + + p = nprobs[:n_ngrams] / nprobs[:n_ngrams].sum(keepdims=True) + n = numpy.random.choice(ngram_nvals[:n_ngrams], p=p) + + while True: + ngram_indices = sum(candidate_ngrams[candidate_idx][n - 1], []) + n = n - 1 + # Success: swapping this N-gram puts us below the desired number of swappings + if n_swappings >= len(permuted_indices) + len(ngram_indices): + skip_candidate = False + break + # Failure: no N-grams remain for this candidate + if n == 0: + skip_candidate = True + break + + # Do nothing for candidates whose 1-gram is too long + if skip_candidate: + continue + + # Do nothing for candidate indices which have already been masked or permuted + if 
any( + map(lambda idx: idx in masked_indices or idx in permuted_indices, ngram_indices) + ): + continue + + for index in ngram_indices: + permuted_indices.add(index) + + assert len(permuted_indices) <= n_swappings + + permuted_indices = sorted(permuted_indices) + permuted_indices_copy = list(permuted_indices) + numpy_random_state.shuffle(permuted_indices_copy) + masked_token_ids_copy = list(masked_token_ids) + + for idx, idx_copy in zip(permuted_indices, permuted_indices_copy): + masked_token_ids[idx] = masked_token_ids_copy[idx_copy] + masked_positions_and_labels.append((idx, masked_token_ids_copy[idx])) + + masked_positions_and_labels = sorted(masked_positions_and_labels, key=lambda x: x[0]) + masked_positions = [] + masked_labels = [] + for position, label in masked_positions_and_labels: + masked_positions.append(position) + masked_labels.append(label) + + masked_spans = sorted(masked_spans, key=lambda x: x[0][0]) + + return masked_token_ids, masked_positions, masked_labels, boundaries, masked_spans + + @abstractmethod + def _get_token_mask(self, numpy_random_state: numpy.random.RandomState) -> Optional[int]: + pass diff --git a/megatron/core/datasets/megatron_dataset.py b/megatron/core/datasets/megatron_dataset.py index d75a645509..15a9a53328 100644 --- a/megatron/core/datasets/megatron_dataset.py +++ b/megatron/core/datasets/megatron_dataset.py @@ -2,134 +2,138 @@ import hashlib import json -from abc import ABC, abstractmethod, abstractstaticmethod +from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Dict, List +from typing import Any, Dict, Iterable, List, Optional, Union import numpy import torch from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig -from megatron.core.datasets.indexed_dataset import MMapIndexedDataset +from megatron.core.datasets.indexed_dataset import IndexedDataset from megatron.core.datasets.utils import Split +LowLevelDataset = Union[IndexedDataset, Iterable] + class MegatronDataset(ABC, torch.utils.data.Dataset): - """The wrapper class from which dataset classes should inherit e.g. GPTDataset + """The highest level wrapper class from which all dataset classes should inherit Args: - indexed_dataset (MMapIndexedDataset): The MMapIndexedDataset around which to build the - MegatronDataset + dataset (LowLevelDataset): The dataset around which to build the MegatronDataset + + dataset_path (Optional[str]): The real path on disk to the dataset, for bookkeeping - indexed_indices (numpy.ndarray): The set of the documents indices to expose + indices (numpy.ndarray): The set of the documents indices to expose - num_samples (int): The number of samples to draw from the indexed dataset + num_samples (Optional[int]): The minimum number of samples to build from the indexed dataset. When None, build as many samples as correspond to one epoch. 
- index_split (Split): The indexed_indices Split + index_split (Split): The indices Split - config (BlendedMegatronDatasetConfig): The container for all config sourced parameters + config (BlendedMegatronDatasetConfig): The config """ def __init__( self, - indexed_dataset: MMapIndexedDataset, - indexed_indices: numpy.ndarray, - num_samples: int, + dataset: LowLevelDataset, + dataset_path: Optional[str], + indices: numpy.ndarray, + num_samples: Optional[int], index_split: Split, config: BlendedMegatronDatasetConfig, ) -> None: - assert indexed_indices.size > 0 - assert num_samples > 0 - assert self.is_multimodal() == indexed_dataset.multimodal - assert self.is_split_by_sequence() != self.is_split_by_document() - - self.indexed_dataset = indexed_dataset - self.indexed_indices = indexed_indices + self.dataset = dataset + self.dataset_path = dataset_path + self.indices = indices self.num_samples = num_samples self.index_split = index_split self.config = config self.unique_identifiers = OrderedDict() + self.unique_identifiers["class"] = type(self).__name__ - self.unique_identifiers["path_prefix"] = self.indexed_dataset.path_prefix + self.unique_identifiers["dataset_path"] = self.dataset_path self.unique_identifiers["num_samples"] = self.num_samples self.unique_identifiers["index_split"] = self.index_split.name for attr in self._key_config_attributes(): self.unique_identifiers[attr] = getattr(self.config, attr) - self.unique_description = json.dumps(self.unique_identifiers, indent=4) + self.unique_description = json.dumps( + self.unique_identifiers, indent=4, default=lambda obj: obj.unique_identifiers + ) self.unique_description_hash = hashlib.md5( self.unique_description.encode("utf-8") ).hexdigest() - self._finalize() + self.built_anew_on_cache_miss = False - @abstractmethod - def _finalize(self) -> None: - """Build the dataset and assert any subclass-specific conditions - """ - pass + @staticmethod + def numel_low_level_dataset(low_level_dataset: LowLevelDataset) -> int: + """Return the number of elements in the underlying low level dataset for the purpose of + segregating the train/valid/test split indices - @abstractmethod - def __len__(self) -> int: - """Return the length of the dataset + It may be that the low level dataset can be split any number of ways, depending on the mid + level dataset it supports, which is why we define the "number of elements" function + separately from the __len__ function here in the mid level dataset class + + Args: + low_level_dataset (LowLevelDataset): The underlying low level dataset Returns: - int: See abstract implementation + int: The number of elements in the underlying low level dataset """ - pass + raise NotImplementedError - @abstractmethod - def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]: - """Return from the dataset + @staticmethod + def build_low_level_dataset( + dataset_path: str, config: BlendedMegatronDatasetConfig + ) -> LowLevelDataset: + """Build the low level dataset via a function to be called from within + BlendedMegatronDatasetBuilder.build_generic_dataset - Args: - idx (int): The index into the dataset + It may be that the low level dataset spans any subset of train/valid/test splits, which is + why we define a static "build" function separately from the constructor in the mid level + dataset class - Returns: - Dict[str, numpy.ndarray]: See abstract implementation - """ - pass + Args: + dataset_path (str): The real path on disk to the dataset - @abstractstaticmethod - def is_multimodal() -> bool: - """Return True if the 
inheritor class and its internal MMapIndexedDataset are multimodal + config (BlendedMegatronDatasetConfig): The dataset config Returns: - bool: See abstract implementation + LowLevelDataset: The low level dataset """ - pass + raise NotImplementedError - @abstractstaticmethod - def is_split_by_sequence() -> bool: - """Return whether the dataset is split by sequence + @staticmethod + def _key_config_attributes() -> List[str]: + """Return all config attributes which contribute to uniquely identifying the dataset. - For example, the GPT train/valid/test split is document agnostic + These attributes will be used to build a uniquely identifying string and MD5 hash which + will be used to cache/load dataset resources from run to run. Returns: - bool: See abstract implementation + List[str]: The key config attributes """ - pass - - @classmethod - def is_split_by_document(cls) -> bool: - """Return whether the dataset is split by document + return ["random_seed", "sequence_length", "split", "split_matrix", "tokenizer"] - For example, the BERT train/valid/test split is document aware + @abstractmethod + def __len__(self) -> int: + """Return the length of the dataset Returns: - bool: The negation of cls.is_split_by_sequence + int: See abstract implementation """ - return not cls.is_split_by_sequence() + pass - @staticmethod - def _key_config_attributes() -> List[str]: - """Return all config attributes which contribute to uniquely identifying the dataset. + @abstractmethod + def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, numpy.ndarray]]: + """Return from the dataset - These attributes will be used to build a uniquely identifying string and MD5 hash which - will be used to cache/load the dataset from run to run. + Args: + idx (int): The index into the dataset Returns: - List[str]: The key config attributes + Dict[str, Union[torch.Tensor, numpy.ndarray]]: See abstract implementation """ - return ["split", "random_seed", "sequence_length"] + pass diff --git a/megatron/core/datasets/megatron_tokenizer.py b/megatron/core/datasets/megatron_tokenizer.py new file mode 100644 index 0000000000..84f3546cf3 --- /dev/null +++ b/megatron/core/datasets/megatron_tokenizer.py @@ -0,0 +1,154 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
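Returning briefly to the `MegatronDataset` interface rewritten above: below is a hedged sketch of a minimal subclass, showing how the new static hooks (`numel_low_level_dataset`, `build_low_level_dataset`) and the abstract `__len__`/`__getitem__` might be wired together. The class name, the sequence-level split, and the returned sample dict are illustrative only.

```python
# Hedged sketch of a minimal MegatronDataset subclass; illustrative, not part of the library.
from typing import Dict

import numpy

from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.megatron_dataset import MegatronDataset


class ToyDataset(MegatronDataset):
    """Expose each underlying sequence as one sample, unchanged."""

    @staticmethod
    def numel_low_level_dataset(low_level_dataset: IndexedDataset) -> int:
        # Split the train/valid/test indices at the sequence level.
        return len(low_level_dataset)

    @staticmethod
    def build_low_level_dataset(
        dataset_path: str, config: BlendedMegatronDatasetConfig
    ) -> IndexedDataset:
        return IndexedDataset(dataset_path)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
        # self.dataset and self.indices are populated by MegatronDataset.__init__.
        return {"tokens": numpy.asarray(self.dataset[self.indices[idx]])}
```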
+import json +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import Any + +import numpy + + +class MegatronTokenizer(ABC): + """Abstract class for tokenizer + + Absent a config or class-specific tracking of which objects are uniquely identifying, we must + include all key word arguments as unique identifiers + + Args: + tokenizer_paths (Tuple[str]): All tokenizer source paths or prefixes + + tokenizer_options (Dict[str, Any]): All tokenizer options + """ + + def __init__(self, *tokenizer_paths: str, **tokenizer_options: Any): + + self.unique_identifiers = OrderedDict() + self.unique_identifiers["class"] = type(self).__name__ + self.unique_identifiers["tokenizer_path"] = list(tokenizer_paths) + for option in tokenizer_options: + self.unique_identifiers[option] = str(tokenizer_options[option]) + + self.unique_description = json.dumps(self.unique_identifiers, indent=4) + + super().__init__() + + @abstractmethod + def tokenize(self, text: str) -> numpy.ndarray: + """Convert text to embedding ids + + Args: + text (str): The text to convert + + Returns: + numpy.ndarray: The converted embedding ids + """ + pass + + def detokenize(self, ids: numpy.ndarray) -> str: + """Convert embedding ids to text + + Args: + ids (numpy.ndarray): The ids to convert + + Returns: + str: The converted text + + Raises: + NotImplementedError: Non-abstract, optional method + """ + raise NotImplementedError("{} has no method 'detokenize'".format(type(self).__name__)) + + def offsets(self, ids: list[int], text: str) -> list[int]: + """Convert embedding ids to text offsets + + Args: + ids (list[int]): The ids to convert + text (str): The text to convert + + Returns: + list[int]: The converted offsets + + Raises: + NotImplementedError: Non-abstract, optional method + """ + raise NotImplementedError("{} has no method 'offsets'".format(type(self).__name__)) + + @property + @abstractmethod + def vocab(self): + """Dictionary from vocab text token to id token""" + pass + + @property + @abstractmethod + def inv_vocab(self): + """Dictionary from vocab id token to text token""" + pass + + @property + @abstractmethod + def vocab_size(self): + """The vocabulary size""" + pass + + @property + def cls(self): + """The CLS token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'cls'".format(type(self).__name__)) + + @property + def sep(self): + """The SEP token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'sep'".format(type(self).__name__)) + + @property + def pad(self): + """The PAD token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'pad'".format(type(self).__name__)) + + @property + def eod(self): + """The EOD token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'eod'".format(type(self).__name__)) + + @property + def bos(self): + """The BOS token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'bos'".format(type(self).__name__)) + + @property + def eos(self): + """The EOS token id + + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'eos'".format(type(self).__name__)) + + @property + def mask(self): + """The MASK token id + + Raises: + 
NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'mask'".format(type(self).__name__)) diff --git a/megatron/core/datasets/multimodal_dataset.py b/megatron/core/datasets/multimodal_dataset.py new file mode 100644 index 0000000000..0a3e93a15b --- /dev/null +++ b/megatron/core/datasets/multimodal_dataset.py @@ -0,0 +1,62 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from dataclasses import dataclass +from typing import Callable, Dict + +import torch + +from megatron.core.datasets.gpt_dataset import GPTDatasetConfig, MockGPTDataset + + +@dataclass +class MultimodalDatasetConfig(GPTDatasetConfig): + """Configuration object for Megatron Core Multimodal datasets. + + Note: This is unused at the moment and may be missing features. Follow-up changes will use this. + """ + + image_h: int = None + """Image height.""" + + image_w: int = None + """Image width.""" + + # Function to preprocess the data sample to a format expected by a specific model. By default, do nothing. + preprocess_func: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = lambda x: x + """Optional function to preprocess data samples for a specific model.""" + + def __post_init__(self) -> None: + super().__post_init__() + + assert self.image_h is not None + assert self.image_w is not None + + +class MockMultimodalDataset(MockGPTDataset): + """Mock multimodal dataset. + + + This is unused at the moment and may be missing features. Follow-up changes will use this. + """ + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """Return a sample that contains a dummy image, text sequence and the associated labels and cost and attention masks. + + Args: + idx (int): The integer seed for mock data generation. + + Returns: + Dict[str, torch.Tensor]: The mock data. + """ + # Get a text sample. + sample = super().__getitem__(idx) + + # Add mock input image. + sample["image"] = torch.zeros( + (3, self.config.image_h, self.config.image_w), dtype=torch.float32 + ) + + # Run optional data preprocessing. + preprocess_func = self.config.preprocess_func + + return preprocess_func(sample) diff --git a/megatron/core/datasets/readme.md b/megatron/core/datasets/readme.md index 77d1e5862f..12ade943b5 100644 --- a/megatron/core/datasets/readme.md +++ b/megatron/core/datasets/readme.md @@ -4,18 +4,18 @@ Data preprocessing is built around the following classes: -1. `MMapIndexedDatasetBuilder` -2. `MMapIndexedDataset` +1. `IndexedDatasetBuilder` +2. `IndexedDataset` At the moment, an end-to-end data preprocessing implementation is left to the user. See the class docstring(s) for more details. -#### MMapIndexedDatasetBuilder +#### IndexedDatasetBuilder -The `MMapIndexedDatasetBuilder` is capable of building and merging `MMapIndexedDataset` instances. +The `IndexedDatasetBuilder` is capable of building and merging `IndexedDataset` instances. -#### MMapIndexedDataset +#### IndexedDataset -The `MMapIndexedDataset` class is the lowest-level data interface in Megatron Core. Internally, an `MMapIndexedDataset` instance references two binaries: the data file (`.bin`) contains document/sequence data and the index file (`.idx`) contains document/sequence metadata. +The `IndexedDataset` class is the lowest-level data interface in Megatron Core. Internally, an `IndexedDataset` instance references two binaries: the data file (`.bin`) contains document/sequence data and the index file (`.idx`) contains document/sequence metadata. 
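Since the readme notes that an end-to-end preprocessing implementation is left to the user, here is a hedged round-trip sketch of the builder/reader pair, assuming `IndexedDatasetBuilder` keeps the `add_item` and `finalize` methods of the former `MMapIndexedDatasetBuilder`; the paths and token ids are made up.

```python
# Hedged end-to-end sketch; paths, tokens, and the add_item/finalize signatures are assumptions.
import numpy
import torch

from megatron.core.datasets.indexed_dataset import IndexedDataset, IndexedDatasetBuilder

builder = IndexedDatasetBuilder("/tmp/toy_corpus.bin", dtype=numpy.int32)
for document in ([10, 11, 12, 13], [20, 21, 22]):
    builder.add_item(torch.tensor(document, dtype=torch.int32))
    builder.end_document()
builder.finalize("/tmp/toy_corpus.idx")

reader = IndexedDataset("/tmp/toy_corpus")
print(len(reader), reader[0])  # 2 [10 11 12 13]
```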
The index file stores dataset-level metadata first: - The index header, for backward compatibility @@ -36,7 +36,7 @@ Building the data loaders is a distributed-aware process built around the follow 1. `BlendedMegatronDatasetConfig` 2. `BlendedMegatronDatasetBuilder` -3. `MMapIndexedDataset` +3. `IndexedDataset` 3. `MegatronDataset` 4. `BlendedDataset` @@ -54,16 +54,16 @@ The `BlendedMegatronDatasetBuilder` class builds the highest-level data interfac **NB:** All ranks should attempt to build the dataset via the `BlendedMegatronDatasetBuilder` or the program will hang. Which ranks follow through on their attempts can be controlled via the `BlendedMegatronDatasetConfig`. -#### MMapIndexedDataset +#### IndexedDataset -The `MMapIndexedDataset` class is the lowest-level data interface in Megatron Core. +The `IndexedDataset` class is the lowest-level data interface in Megatron Core. -The `MMapIndexedDataset` should already exist on disk before attempting to build any of the high-level data interfaces. +The `IndexedDataset` should already exist on disk before attempting to build any of the high-level data interfaces. #### MegatronDataset (extendable) -The `MegatronDataset` abstract class is a high-level data interface in Megatron Core. It is an abstraction built upon the `MMapIndexedDataset`. +The `MegatronDataset` abstract class is a high-level data interface in Megatron Core. It is an abstraction built upon the `IndexedDataset`. Different training/inference regimes will require different extensions e.g. the `GPTDataset` @@ -77,7 +77,7 @@ The `BlendedDataset` is only necessary when a blend multiple data distributions, ### GPTDataset -The `GPTDataset` is parameterized by the following variables: the underlying `MMapIndexedDataset` instance `indexed_dataset`, the split indices `indexed_indices` (the congituous subset of document or sequence indices used for training, validation, and testing), the number of samples `N`, the sequence length `S`, and the random seed `R`. +The `GPTDataset` is parameterized by the following variables: the underlying `IndexedDataset` instance `indexed_dataset`, the split indices `indexed_indices` (the congituous subset of document or sequence indices used for training, validation, and testing), the number of samples `N`, the sequence length `S`, and the random seed `R`. The `GPTDataset` creates three index mappings to facilitate lookup: (1) the document index, (2) the sample index, and (3) the shuffle index. diff --git a/megatron/core/datasets/retro/__init__.py b/megatron/core/datasets/retro/__init__.py new file mode 100644 index 0000000000..7ce970c6e9 --- /dev/null +++ b/megatron/core/datasets/retro/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from .config import RetroGPTChunkDatasets +from .query.multi_split_gpt_dataset import MultiSplitGPTDataset, MultiSplitGPTDatasetConfig +from .query.retro_dataset import get_retro_datasets diff --git a/megatron/core/datasets/retro/config/__init__.py b/megatron/core/datasets/retro/config/__init__.py new file mode 100644 index 0000000000..3635bedb3f --- /dev/null +++ b/megatron/core/datasets/retro/config/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +Exports: + + - Embedder: Base class for all Bert embedders. + - RetroBertEmbedders: Container class for in-memory and on-disk embedders. + - RetroPreprocessingConfig: Configuration class for all of Retro preprocessing. 
+ - RetroGPTChunkDatasets: Container class for train, valid, and test datasets. + - RetroTokenizers: Container class for GPT and Bert tokenizers. +""" + +from .bert_embedders import Embedder, RetroBertEmbedders +from .config import RetroPreprocessingConfig +from .gpt_chunk_datasets import RetroGPTChunkDatasets +from .tokenizers import RetroTokenizers diff --git a/megatron/core/datasets/retro/config/bert_embedders.py b/megatron/core/datasets/retro/config/bert_embedders.py new file mode 100644 index 0000000000..8f3fe85c4a --- /dev/null +++ b/megatron/core/datasets/retro/config/bert_embedders.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Container dataclass for holding both in-memory and on-disk Bert embedders.""" + +import abc +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch + + +class Embedder(abc.ABC): + """Base class for all Bert embedders. + + All embedders should be able to embed either an entire text dataset (to a 2D + numpy array), or a single text string (to a 1D numpy array). + """ + + @abc.abstractmethod + def embed_text_dataset(self, text_dataset: torch.utils.data.Dataset) -> np.ndarray: + """Embed a text dataset. + + Args: + text_dataset (torch.utils.data.Dataset): Text dataset to embed. Each sample of the text dataset should output a dict with a key 'text' and a string value. + + Returns: + A 2D ndarray with shape (len(text_dataset), dimension(embedder)). + """ + + @abc.abstractmethod + def embed_text(self, text: str) -> np.ndarray: + """Embed a simple string of text. + + Args: + text (str): A single text sample. + + Returns: + A 1D ndarray with shape (dimensions(embedder),). + """ + + +@dataclass +class RetroBertEmbedders: + """Container dataclass for in-memory and on-disk Bert embedders.""" + + disk: Embedder + mem: Embedder diff --git a/megatron/core/datasets/retro/config/config.py b/megatron/core/datasets/retro/config/config.py new file mode 100644 index 0000000000..ac9ca84124 --- /dev/null +++ b/megatron/core/datasets/retro/config/config.py @@ -0,0 +1,135 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Retro preprocessing config.""" + +from dataclasses import dataclass + +from megatron.core.transformer import TransformerConfig + +from .bert_embedders import RetroBertEmbedders +from .gpt_chunk_datasets import RetroGPTChunkDatasets +from .tokenizers import RetroTokenizers + + +@dataclass +class RetroPreprocessingConfig(TransformerConfig): + """Configuration object for Retro preprocessing. + + *Note* : Arguments prefixed with '--retro-gpt-*' or '--retro-bert-*' are + included and named as such to more easily handle managing both models + running at the same time. Megatron is not optimized to run two models at + once, so this naming convention makes it clearer. + + Args: + + retro_project_dir (str): Retro project directory, which contains the preprocessed data for for pretraining. This directory is built during preprocessing (see tools/retro/README.md), and contains subdirectories for the chunk database and pretraining neighbors. + retro_tasks (str): Comma-separated list of tasks to run. Run entire preprocesing pipeline by using '--retro-tasks build'. Alternatively, run individual stages with tasks (in this order) 'db-build', 'index-build', or 'query-pretraining-neighbors'. For example, '--retro-tasks db-build,index-build,query-pretraining-neighbors' is equivalent to '--retro-tasks build'; or the argument can contain a subset of these tasks. 
Stages must always be run in the correct order (listed above). + retro_task_validate (float): If defined, validate a randomly sampled subset of the existing results of the given task. Each task implements a 'validate' method that is responsible for sampling a `retro_task_validate` fraction of the existing results, and then checking for bitwise equality with the current code base. (E.g., `--retro-task-validate 0.01`.) + retro_block_size (int): Number of chunks to process at a time when generating Bert embeddings and querying the search index. Partial results for each block are generally saved to disk in separate files. + retro_doc_block_size (int): Number of documents to processe at time when processing token datasets into chunk databases. The partial chunk database for each block is saved into a separate file. + retro_gpt_seed (int): Random seed used for python, numpy, pytorch, and cuda. + retro_gpt_data_path (str): Path to the training dataset. Accepted format: 1) a single data path, 2) multiple datasets in the form: dataset1-weight dataset1-path dataset2-weight dataset2-path ... It is used with --split when a single dataset used for all three: train, valid and test. It is exclusive to the other --*-data-path args. + retro_gpt_data_cache_path (str): Path to a directory to hold cached index files. + retro_gpt_split (str): Comma-separated list of proportions for training, validation, and test split. For example the split `90,5,5` will use 90%% of data for training, 5%% for validation and 5%% for test. + retro_gpt_train_samples (int): Total number of samples to train over all training runs. + retro_gpt_eval_interval (int): GPT evaluation interval. + retro_gpt_eval_iters (int): GPT evaluation iterations. + retro_gpt_tokenizer_type (str): GPT tokenizer type. + retro_gpt_tokenizer_model (str): GPT tokenizer model file. + retro_gpt_vocab_file (str): GPT vocab file. + retro_gpt_merge_file (str): GPT merge file. + retro_gpt_seq_length (int): GPT sequence length. + retro_gpt_global_batch_size (int): GPT global batch size. + retro_gpt_chunk_length (int): GPT chunk length. + retro_bert_tokenizer_type (str): Bert tokenizer type (for when using '--bert-embedder-type megatron'). + retro_bert_vocab_file (str): Bert vocab file. + retro_bert_batch_size (int): Micro-batch size for processing Bert embeddings. + retro_bert_max_chunk_length (int): Maximum sequence length for Bert embeddings. (Named 'chunk' here in reference to these Bert sequences being converted from GPT chunks.) + retro_index_type (str): A 'faiss-base' index is a simple, un-optimized wrapper around a Faiss index. A 'faiss-par-add' index optimizes the 'add()' method by making it multi-node and multi-process, but with bit-wise equivalent results. + retro_index_str (str): Index string used for calling faiss.index_factory(). For example, 'IVF262144_HNSW32,Flat' or 'OPQ32_256,IVF4194304_HNSW32,PQ32'. + retro_index_ntrain (int): Number of database chunks to use for training the index. This value must be less or equal to the total number of chunks in the database. + retro_index_train_load_fraction (float): Fraction of sampled chunks to use for training the index. Useful when our total sampled embeddings use too much memory; lowering the load fraction is less costly than re-embedding a new sampled dataset from scratch. + retro_index_add_load_fraction (float): Fraction of database chunks to use for adding to the index. 
Useful when our total index size would use too much memory; lowering the load fraction is less costly than re-designing our token datasets. + retro_index_delete_training_embeddings (bool): Delete training embeddings for the search index. Useful for debugging. + retro_index_delete_added_codes (bool): Delete added codes for the search index. Useful for debugging. + retro_query_ef_search (int): Index ef-search parameter for Hierarchical Navigable Small Worlds (HNSW) during querying. + retro_query_nprobe (int): Index nprobe parameter for Inverted File (IVF) during querying. + retro_query_num_neighbors_query (int): Number of neighbors to retrieve when calling index.search(). + retro_query_num_neighbors_save (int): Number of neighbors to save to disk after the index's returned neighbors. If longer than target value, neighbors truncated; and if shorter than target value, neighbors are padded with -1's. + retro_bert_embedders (RetroBertEmbedders): Set of Bert embedders used for embedding chunks. Contains entries: 1) 'mem' for an in-memory embedder, and 2) 'disk' for an embedder that saves results in blocks to disk. + retro_gpt_chunk_datasets (RetroGPTChunkDatasets): GPT datasets for 'train', 'valid', and 'test'. + retro_tokenizers (RetroTokenizers): GPT ('gpt') and Bert ('bert') tokenizers. + """ + + # Basic. + retro_project_dir: str = None + retro_tasks: str = 'build' + retro_task_validate: float = None + retro_block_size: int = 100000 + retro_doc_block_size: int = 100000 + + # GPT. + retro_gpt_seed: int = 1234 + retro_gpt_data_path: list = None # basic list here, for parsing purposes + retro_gpt_data_cache_path: str = None + retro_gpt_split: str = '969,30,1' + retro_gpt_train_samples: int = None + retro_gpt_eval_interval: int = None + retro_gpt_eval_iters: int = None + retro_gpt_tokenizer_type: str = None + retro_gpt_tokenizer_model: str = None + retro_gpt_vocab_file: str = None + retro_gpt_merge_file: str = None + retro_gpt_seq_length: int = None + retro_gpt_global_batch_size: int = None + retro_gpt_chunk_length: int = 64 + + # Bert. + retro_bert_tokenizer_type: str = None + retro_bert_vocab_file: str = None + retro_bert_batch_size: int = 128 + retro_bert_max_chunk_length: int = 256 + + # Index. + retro_index_type: str = 'faiss-par-add' + retro_index_str: str = None + retro_index_ntrain: int = None + retro_index_train_load_fraction: float = 1.0 + retro_index_add_load_fraction: float = 1.0 + retro_index_delete_training_embeddings: bool = True + retro_index_delete_added_codes: bool = True + + # Query. + retro_query_ef_search: int = 256 + retro_query_nprobe: int = 65536 + retro_query_num_neighbors_query: int = 200 + retro_query_num_neighbors_save: int = 20 + + # Tools. + retro_bert_embedders: RetroBertEmbedders = None + retro_gpt_chunk_datasets: RetroGPTChunkDatasets = None + retro_tokenizers: RetroTokenizers = None + + def __post_init__(self) -> None: + """Validate Retro config.""" + + # Validate required attributes. 
+ assert self.retro_project_dir is not None + assert self.retro_tasks is not None + assert self.retro_gpt_data_path is not None or self.retro_gpt_data_cache_path is not None + assert self.retro_gpt_train_samples is not None + assert self.retro_gpt_eval_interval is not None + assert self.retro_gpt_eval_iters is not None + assert self.retro_gpt_tokenizer_type is not None + assert self.retro_gpt_tokenizer_model is not None or ( + self.retro_gpt_vocab_file is not None and self.retro_gpt_merge_file is not None + ) + assert self.retro_gpt_seq_length is not None + assert self.retro_gpt_global_batch_size is not None + assert self.retro_bert_tokenizer_type is not None + assert self.retro_bert_vocab_file is not None + assert self.retro_index_str is not None + assert self.retro_index_ntrain is not None + + # Split retro tasks. + self.retro_tasks = self.retro_tasks.split(",") diff --git a/megatron/core/datasets/retro/config/gpt_chunk_datasets.py b/megatron/core/datasets/retro/config/gpt_chunk_datasets.py new file mode 100644 index 0000000000..831b1d812b --- /dev/null +++ b/megatron/core/datasets/retro/config/gpt_chunk_datasets.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Container dataclass for GPT chunk datasets (train, valid, and test).""" + +from dataclasses import dataclass + + +@dataclass +class RetroGPTChunkDatasets: + """Container dataclass for GPT chunk datasets.""" + + # Each dict contains 'dataset', 'neighbor_dir', and 'num_active_chunks'. + train: dict = None + valid: dict = None + test: dict = None diff --git a/megatron/core/datasets/retro/config/tokenizers.py b/megatron/core/datasets/retro/config/tokenizers.py new file mode 100644 index 0000000000..2e731c83b9 --- /dev/null +++ b/megatron/core/datasets/retro/config/tokenizers.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Container class for GPT and Bert tokenizers.""" + +from dataclasses import dataclass + +from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer + + +@dataclass +class RetroTokenizers: + """Container class for GPT and Bert tokenizers.""" + + gpt: MegatronTokenizer = None + bert: MegatronTokenizer = None diff --git a/megatron/core/datasets/retro/db/__init__.py b/megatron/core/datasets/retro/db/__init__.py new file mode 100644 index 0000000000..f1f460b3b0 --- /dev/null +++ b/megatron/core/datasets/retro/db/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +Exports: + + - build_db: Build a chunk database from a list of indexed datasets. +""" + +from .build import build_db diff --git a/megatron/core/datasets/retro/db/build.py b/megatron/core/datasets/retro/db/build.py new file mode 100644 index 0000000000..44b9038230 --- /dev/null +++ b/megatron/core/datasets/retro/db/build.py @@ -0,0 +1,633 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Build a chunk database from a list of indexed datasets. + +Building a chunk database consists of. + + - Breaking each document of each indexed dataset into consecutive + retro_gpt_chunk_length chunks. + - Re-tokenize each chunk into Bert, and discard any chunks with empty Bert + tokens. + - Save chunk offsets to disk for each indexed dataset. 
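The module docstring above describes splitting each document into consecutive `retro_gpt_chunk_length` chunks; the small sketch below (with a hypothetical helper name, not part of the module) shows the start/end arithmetic that the chunking step in `build_partial_db` relies on, including the shorter final chunk.

```python
# Illustrative-only sketch of the per-document chunking described above.
def chunk_bounds(doc_len: int, chunk_length: int):
    starts = list(range(0, doc_len, chunk_length))
    ends = [min(doc_len, s + chunk_length) for s in starts]
    return list(zip(starts, ends))

# e.g. a 150-token document with 64-token chunks -> [(0, 64), (64, 128), (128, 150)]
assert chunk_bounds(150, 64) == [(0, 64), (64, 128), (128, 150)]
```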
+""" + +import glob +import os +import types +from concurrent.futures import ProcessPoolExecutor, as_completed +from typing import Dict, List, Tuple + +import numpy as np +import torch +from tqdm import tqdm + +from megatron.core.datasets.indexed_dataset import IndexedDataset +from megatron.core.datasets.retro.config import RetroPreprocessingConfig +from megatron.core.datasets.retro.external_libs import h5py +from megatron.core.datasets.retro.utils import ( + extract_data_config, + get_blocks_by_rank, + log_retro_rank_0, + retro_makedir, +) + +from .utils import ( + get_indexed_dataset_infos, + get_indexed_dataset_infos_path, + get_individual_chunk_db, + get_individual_db_dir, + get_individual_db_paths, + get_individual_doc_offsets, + get_merged_db_path_map, + init_indexed_dataset_infos, + load_indexed_datasets, + save_indexed_dataset_infos, +) + + +def build_partial_db( + config: types.SimpleNamespace, + dataset_idx: int, + n_datasets: int, + indexed_dataset: IndexedDataset, + block_id: int, + n_blocks: int, + block: dict, + proc_id: int, + n_procs: int, +) -> Tuple[int, list, list, dict]: + """Process a document index range of the indexed dataset. + + The chunk database is built in parallel blocks, since de-tokenizing & + re-tokenizing for Bert-length computation is expensive. This method + iterates each document and extracts sequential 'chunk-length' sequences + from each document. + + Args: + config (types.SimpleNamespace): Subset of Retro config, containing 'chunk_length', 'gpt_eod', 'gpt_detokenize', 'bert_tokenize', and 'task_validate'. + dataset_idx (int): Index of this dataset out of all blended datasets. + n_datasets (int): Total number of blended datasets. + indexed_dataset (IndexedDataset): Indexed dataset to be chunked. + block_id (int): Block index out of all blocks to be processed. + n_blocks (int): Total number of blocks to be processed. + block (dict): Range information such as start/end points for chunking idnexed dataset. + proc_id (int): Process ID for tracking parallel process order. + n_procs (int): Total number of parallel processes. + + Returns: + A tuple containing: + + - Process ID. + - List of valid chunks. + - List of invalid chunks (i.e., chunks that converted to empty Bert embeddings.). + - Dict mapping document ID to number of valid chunks. + """ + + # Document start/end indexes. + doc_range = block["range"] + n_docs = doc_range[1] - doc_range[0] + n_docs_per_proc = int(np.ceil(n_docs / n_procs)) + doc_start_id = doc_range[0] + proc_id * n_docs_per_proc + doc_end_id = min(doc_range[1], doc_start_id + n_docs_per_proc) + + # Print progress. + progress_proc_ids = set(range(n_procs)) if torch.distributed.get_rank() == 0 else set() + if proc_id in progress_proc_ids: + log_retro_rank_0( + " > building partial chunk db, proc %d / %d, docs %d:%d / %d." + % (proc_id, n_procs, doc_start_id, doc_end_id, n_docs) + ) + + # Progress bars (snapshot of overall progress). + doc_id_iter = range(doc_start_id, doc_end_id) + pbar = ( + tqdm(doc_id_iter, "parse doc chunks", miniters=len(doc_id_iter) // 20) + if proc_id in progress_proc_ids + else doc_id_iter + ) + + # Iterate documents & parse chunks. + chunk_db_valid: List[Tuple] = [] + chunk_db_invalid: List[Tuple] = [] + doc_size_map = {} + for doc_id in pbar: + + # Progress description. + try: + pbar.set_description( + "%sds %d / %d, block %d / %d, proc %d / %d." 
+ % ( + "" if config.task_validate is None else "[validate] ", + dataset_idx, + n_datasets, + block_id, + n_blocks, + proc_id, + n_procs, + ) + ) + except Exception: + pass + + # Remove EOD token. + doc = indexed_dataset.get(doc_id) + if doc[-1].item() == config.gpt_eod: + doc = doc[:-1] + doc_len = len(doc) + + # Chunk start/end indexes. + chunk_start_idxs = list(range(0, doc_len, config.chunk_length)) + chunk_end_idxs = [min(doc_len, s + config.chunk_length) for s in chunk_start_idxs] + + # Re-tokenize each chunk to Bert/Wordpiece (empty bert -> 'invalid'). + doc_size_map[doc_id] = 0 + for i, chunk_start_idx in enumerate(chunk_start_idxs): + + # Re-tokenize. + chunk_end_idx = chunk_end_idxs[i] + gpt_token_ids = indexed_dataset.get( + idx=doc_id, offset=chunk_start_idx, length=chunk_end_idx - chunk_start_idx + ) + text = config.gpt_detokenize(gpt_token_ids.tolist()) + bert_token_ids = config.bert_tokenize(text) + + # 'Valid' for non-empty Bert chunks; 'invalid' otherwise. + if len(bert_token_ids) == 0: + _chunk_db = chunk_db_invalid + else: + _chunk_db = chunk_db_valid + doc_size_map[doc_id] += 1 + _chunk_db.append((doc_id, chunk_start_idx, chunk_end_idx, len(bert_token_ids))) + + return proc_id, chunk_db_valid, chunk_db_invalid, doc_size_map + + +def build_block_db( + config: RetroPreprocessingConfig, + dataset_idx: int, + n_datasets: int, + indexed_dataset: IndexedDataset, + n_procs: int, + executor: ProcessPoolExecutor, + n_missing_blocks: int, + block_idx: int, + block: dict, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Split each document within block into consecutive retro_gpt_chunk_length size chunks. + + Args: + config (RetroPreprocessingConfig): For DB building, we make use of attributes 'chunk_length', 'gpt_eod', 'gpt_detokenize', 'bert_tokenize', and 'task_validate'. + dataset_idx (int): Index of this dataset out of all blended datasets. + n_datasets (int): Total number of blended datasets. + indexed_dataset (IndexedDataset): Indexed dataset to be chunked. + n_procs (int): Total number of parallel processes. + executor (ProcessPoolExecutor): Executor for launching parallel processes. + n_missing_blocks (int): Total number of blocks to be processed. + block_idx (int): Block index out of all blocks to be processed. + block (dict): Range information such as start/end points for chunking idnexed dataset. + + Returns: + A tuple containing: + + - List of valid chunks. + - List of invalid chunks (i.e., chunks that converted to empty Bert embeddings.). + - Dict mapping document ID to number of valid chunks. + """ + + # Build partial dbs. + log_retro_rank_0(' > build partial dbs.') + futures = [] + for proc_id in range(n_procs): # not true process id + futures.append( + executor.submit( + build_partial_db, + types.SimpleNamespace( + chunk_length=config.retro_gpt_chunk_length, + gpt_eod=config.retro_tokenizers.gpt.eod, + gpt_detokenize=config.retro_tokenizers.gpt.detokenize, + bert_tokenize=config.retro_tokenizers.bert.tokenize, + task_validate=config.retro_task_validate, + ), + dataset_idx, + n_datasets, + indexed_dataset, + block_idx, + n_missing_blocks, + block, + proc_id, + n_procs, + ) + ) + partial_chunk_dbs = [] + for future in as_completed(futures): + partial_chunk_dbs.append(future.result()) + + # Concatenate chunks. 
+ partial_chunk_dbs.sort(key=lambda item: item[0]) # sort by proc_id + chunk_db_valid = [ + item for partial_chunk_db in partial_chunk_dbs for item in partial_chunk_db[1] + ] + chunk_db_invalid = [ + item for partial_chunk_db in partial_chunk_dbs for item in partial_chunk_db[2] + ] + + # Convert to numpy. + log_retro_rank_0(' > converting chunk db to numpy.') + chunk_db_valid = np.array(chunk_db_valid, dtype="uint32") + chunk_db_invalid = np.array(chunk_db_invalid, dtype="uint32") + + # Document offsets. + doc_sizes = [ + (d, s) for partial_chunk_db in partial_chunk_dbs for d, s in partial_chunk_db[3].items() + ] + doc_sizes.sort(key=lambda item: item[0]) + doc_offsets = np.cumsum([item[1] for item in doc_sizes]).astype("uint64") + doc_offsets = np.stack( + (np.array([item[0] for item in doc_sizes], dtype="uint64"), doc_offsets), axis=1 + ) + + return chunk_db_valid, chunk_db_invalid, doc_offsets + + +def save_block_db( + block: dict, chunk_db_valid: np.ndarray, chunk_db_invalid: np.ndarray, doc_offsets: np.ndarray +) -> None: + """Save block of chunked tokens to disk. These blocks are later used for + training and adding to the vector index. + + Args: + block (dict): Range information such as start/end points for chunking idnexed dataset. + chunk_db_valid (np.ndarray): Array of valid chunk indexes. + chunk_db_invalid (np.ndarray): Array of invalid chunk indexes. + doc_offsets (np.ndarray): Array of document offsets by chunks. + """ + log_retro_rank_0(" > saving individual db.") + with h5py.File(block["path"], "w") as f: + dset = f.create_dataset("chunks_valid", data=chunk_db_valid) + dset = f.create_dataset("chunks_invalid", data=chunk_db_invalid) + dset = f.create_dataset("doc_offsets", data=doc_offsets) + + +def build_individual_db( + config: RetroPreprocessingConfig, dataset_idx: int, n_datasets: int, dataset_info: dict +) -> None: + """Process a single indexed dataset & extract chunks. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + dataset_idx (int): Dataset index within blended dataset. + n_datasets (int): Total number of datasets within blended dataset. + dataset_info (dict): Metadata for dataset (see `save_indexed_dataset_infos()` in `utils.py` for more detail). + """ + + # Make directory. + db_dir = get_individual_db_dir(config.retro_project_dir, dataset_info["prefix"]) + retro_makedir(config, db_dir) + + # Indexed dataset. + indexed_dataset = dataset_info["dataset"] + + # Missing DB blocks (split by documents). + blocks = get_blocks_by_rank( + db_dir, + len(indexed_dataset), + config.retro_doc_block_size, + validate=lambda f: f["chunks_valid"].shape == (0,) or f["chunks_valid"].shape[1] == 4, + sample=config.retro_task_validate, + ) + if config.retro_task_validate is None: + active_blocks = blocks.missing + else: + assert blocks.n_missing_world == 0 + active_blocks = blocks.existing + + # Prevent missing-path-write race condition. + torch.distributed.barrier() + + # Nothing to do? + if config.retro_task_validate is None and not active_blocks: + return + + # Num processes. + if blocks.n_missing_world == 1: + n_procs = 128 + elif blocks.n_missing_world <= 2: + n_procs = 64 + elif blocks.n_missing_world <= 4: + n_procs = 32 + elif blocks.n_missing_world <= 8: + n_procs = 16 + else: + n_procs = 8 + + # Process documents in parallel. + with ProcessPoolExecutor(max_workers=n_procs) as executor: + for block_idx, block in enumerate(active_blocks): + + if block is not None: + + # Build block DB. 
+ chunk_db_valid, chunk_db_invalid, doc_offsets = build_block_db( + config=config, + dataset_idx=dataset_idx, + n_datasets=n_datasets, + indexed_dataset=indexed_dataset, + n_procs=n_procs, + executor=executor, + n_missing_blocks=len(active_blocks), + block_idx=block_idx, + block=block, + ) + + if config.retro_task_validate is None: + # Save block DB. + save_block_db( + block=block, + chunk_db_valid=chunk_db_valid, + chunk_db_invalid=chunk_db_invalid, + doc_offsets=doc_offsets, + ) + + else: + + # Load existing block DB. + with h5py.File(block["path"]) as f: + existing_chunks_valid = np.copy(f["chunks_valid"]) + existing_chunks_invalid = np.copy(f["chunks_invalid"]) + existing_doc_offsets = np.copy(f["doc_offsets"]) + + # Check equality. + log_retro_rank_0(" > validate.") + assert np.array_equal(existing_chunks_valid, chunk_db_valid) + assert np.array_equal(existing_chunks_invalid, chunk_db_invalid) + assert np.array_equal(existing_doc_offsets, doc_offsets) + + # Wait for all ranks to finish block. + log_retro_rank_0(" > waiting for all ranks to finish block.") + torch.distributed.barrier() + + log_retro_rank_0(" > finished saving individual db.") + + +def build_individual_dbs( + config: RetroPreprocessingConfig, indexed_dataset_infos: List[Dict] +) -> None: + """Iterate each indexed dataset & process its chunks. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + indexed_dataset_infos (List[Dict]): Preprocessing metadata for each dataset. + """ + + # Build individual DBs. + log_retro_rank_0(" > build individual chunk dbs.") + for ds_idx, ds_info in enumerate(indexed_dataset_infos): + + # Progress. + log_retro_rank_0( + " > building individual db, dataset %d / %d ... '%s'." + % (ds_idx, len(indexed_dataset_infos), ds_info["prefix"]) + ) + + # Process single dataset. + build_individual_db(config, ds_idx, len(indexed_dataset_infos), ds_info) + + +def update_chunk_counts( + config: RetroPreprocessingConfig, indexed_dataset_infos: List[Dict] +) -> None: + """Set n_chunks_train & n_chunks sampled for each individual DB. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + indexed_dataset_infos (List[Dict]): Preprocessing metadata for each dataset (i.e., 'prefix', 'ratio', 'n_chunks', etc.). + """ + + if torch.distributed.get_rank() != 0: + return + + # Data ratio sum (for setting index training chunks). + data_ratio_sum = sum([d["ratio"] for d in indexed_dataset_infos]) + + # Training split size (split at document level). + train_fraction = float(extract_data_config(config).split.split(",")[0]) / 100 + assert train_fraction > 0 and train_fraction <= 1 + + # Set n_chunks (including n_chunks_sampled for unambiguity). + log_retro_rank_0(" > compute n_chunks.") + for ds_index, ds_info in enumerate(indexed_dataset_infos): + + db_paths = get_individual_db_paths(config.retro_project_dir, ds_info["prefix"]) + + # Update counts. 
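# --- Illustrative aside (not part of the patch): how the document-level training split is
# derived from a Megatron-style "train,valid,test" split string in update_chunk_counts()
# above. The "98,2,0" split and the document count are example values only.
split_str = "98,2,0"
train_fraction = float(split_str.split(",")[0]) / 100   # -> 0.98
n_docs = 1_000_000
n_docs_train = int(train_fraction * n_docs)             # first 980,000 documents
# Chunks whose doc id falls below n_docs_train are the ones counted into 'n_chunks_train'.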
+ ds_info["n_docs"] = len(ds_info["dataset"].document_indices) - 1 + ds_info["n_docs_train"] = int(train_fraction * ds_info["n_docs"]) + ds_info["n_chunks"] = 0 # previously, 'n_chunks_valid' + ds_info["n_chunks_train"] = 0 + ds_info["n_chunks_invalid"] = 0 + for db_path in tqdm( + db_paths, "%d/%d, %s" % (ds_index, len(indexed_dataset_infos), ds_info["prefix"]) + ): + with h5py.File(db_path, "r") as f: + ds_info["n_chunks"] += len(f["chunks_valid"]) + ds_info["n_chunks_invalid"] += len(f["chunks_invalid"]) + ds_info["n_chunks_train"] += ( + (np.copy(f["chunks_valid"][:, 0]) < ds_info["n_docs_train"]).sum().item() + ) + + ds_info["n_chunks_sampled"] = int( + config.retro_index_ntrain * ds_info["ratio"] / data_ratio_sum + ) + + # Verify counts. + assert ds_info["n_chunks_train"] <= ds_info["n_chunks"], "n_train (%d) > n_total (%d)." % ( + ds_info["n_chunks_train"], + ds_info["n_chunks"], + ) + assert ( + ds_info["n_chunks_sampled"] <= ds_info["n_chunks_train"] + ), "n_sampled (%d) > n_train (%d)." % ( + ds_info["n_chunks_sampled"], + ds_info["n_chunks_train"], + ) + + +def merge_dbs(project_dir: str, indexed_dataset_infos: List[Dict], db_type: str) -> None: + """Merge individual DBs into single DB. + + Args: + project_dir (str): Retro project dir. + indexed_dataset_infos (List[Dict]): Preprocessing metadata for each dataset (i.e., 'prefix', 'ratio', 'n_chunks', etc.). + db_type (str): DB type (e.g., 'sampled', 'train', or 'valid'). + """ + + if torch.distributed.get_rank() != 0: + return + + log_retro_rank_0(" > build %s chunk db." % db_type) + + # Count chunks. + if db_type == "sampled": + n_chunks_key = "n_chunks_sampled" + n_docs_key = None + elif db_type == "train": + n_chunks_key = "n_chunks_train" + n_docs_key = "n_docs_train" + elif db_type == "valid": + n_docs_key = None + else: + raise Exception("handle db_type '%s'." % db_type) + + if db_type == "valid": + n_chunks = sum(m["n_chunks"] - m["n_chunks_train"] for m in indexed_dataset_infos) + else: + n_chunks = sum(m[n_chunks_key] for m in indexed_dataset_infos) + n_docs = None if n_docs_key is None else sum(m[n_docs_key] for m in indexed_dataset_infos) + + # DB path. + db_path = get_merged_db_path_map(project_dir)[db_type] + + # Delete existing chunk db if incorrect size. + if os.path.exists(db_path): + + try: + + f = h5py.File(db_path) + n_alloc = len(f["chunks"]) # total allocated + n_written = f["n_written"][0].item() # total written + f.close() + + if n_chunks != n_alloc or n_chunks != n_written: + os.remove(db_path) + + except Exception as e: + if isinstance(e, OSError): + os.remove(db_path) + elif isinstance(e, KeyError): + f.close() + os.remove(db_path) + else: + raise e + + # Build merged chunk db. + if not os.path.exists(db_path): + + os.makedirs(os.path.dirname(db_path), exist_ok=True) + f = h5py.File(db_path, "w") + + # Initialize output arrays. + merged_chunk_db: np.ndarray = f.create_dataset("chunks", (n_chunks, 5), dtype="uint32") + merged_doc_offsets: np.ndarray = ( + None + if n_docs_key is None + else f.create_dataset("doc_offsets", (n_docs, 3), dtype="uint64") + ) + n_written = f.create_dataset("n_written", (1,), dtype="uint64") + n_written[0] = 0 + + # Iterate indexed datasets & collect chunks. + chunk_start_index = 0 + doc_start_index = 0 + doc_start_offset = 0 + for ds_idx, ds_info in enumerate(indexed_dataset_infos): + log_retro_rank_0( + " > merging dbs; '%s', dataset %d / %d ... '%s'." 
+ % (db_type, ds_idx, len(indexed_dataset_infos), ds_info["prefix"]) + ) + individual_chunk_db: np.ndarray = get_individual_chunk_db(project_dir, ds_idx, ds_info) + individual_doc_offsets: np.ndarray = ( + None + if n_docs_key is None + else get_individual_doc_offsets(project_dir, ds_idx, ds_info) + ) + + if db_type == "valid": + individual_chunk_db = individual_chunk_db[ds_info["n_chunks_train"] :] + if n_docs_key is None: + individual_doc_offsets = None + else: + train_doc_offset = individual_doc_offsets[ds_info["n_docs_train"] - 1, 2] + individual_doc_offsets = np.copy( + individual_doc_offsets[ds_info["n_docs_train"] :] + ) + individual_doc_offsets[:, 2] -= train_doc_offset + + log_retro_rank_0("~~~") + log_retro_rank_0(individual_doc_offsets) + log_retro_rank_0(train_doc_offset) + raise Exception("test me.") + else: + individual_chunk_db = individual_chunk_db[: ds_info[n_chunks_key]] + individual_doc_offsets = ( + None + if n_docs_key is None + else np.copy(individual_doc_offsets[: ds_info[n_docs_key]]) + ) + + merged_chunk_db[chunk_start_index : chunk_start_index + len(individual_chunk_db)] = ( + individual_chunk_db + ) + chunk_start_index += len(individual_chunk_db) + n_written[0] = chunk_start_index + if n_docs_key is not None: + individual_doc_offsets[:, 2] += doc_start_offset + doc_end_index = doc_start_index + individual_doc_offsets.shape[0] + merged_doc_offsets[doc_start_index:doc_end_index] = individual_doc_offsets + doc_start_index = doc_end_index + doc_start_offset = individual_doc_offsets[-1, 2].item() + + f.close() + + +def build_merged_dbs(project_dir: str, indexed_dataset_infos: List[Dict]) -> None: + """Merge individual dataset components into single database. + + This method merges databases for DB types: + - 'sampled': used for training the vector index. + - 'train': used for adding to the trained vector index. + - 'valid': can be used for validating/testing the vector index. + + Args: + project_dir (str): Retro project dir. + indexed_dataset_infos (List[Dict]): Preprocessing metadata for each dataset (i.e., 'prefix', 'ratio', 'n_chunks', etc.). + """ + merge_dbs(project_dir, indexed_dataset_infos, "sampled") + merge_dbs(project_dir, indexed_dataset_infos, "train") + merge_dbs(project_dir, indexed_dataset_infos, "valid") + + +def build_db(config: RetroPreprocessingConfig) -> None: + """Extract token chunks from each indexed dataset. + + Iterate each document of each indexed dataset, extract that document's chunks, and save to a 'DB' (hdf5 file). + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + """ + + project_dir = config.retro_project_dir + + # Indexed dataset info. + if config.retro_task_validate is None: + indexed_dataset_infos = init_indexed_dataset_infos(config) + else: + indexed_dataset_infos = get_indexed_dataset_infos(config.retro_project_dir) + # Build individual dbs. + build_individual_dbs(config, indexed_dataset_infos) + + # If validating, return here. + if config.retro_task_validate is not None: + return + + # Single-process going forward. + if torch.distributed.get_rank() != 0: + return + + # Update n_chunks & save indexed dataset infos. + if not os.path.exists(get_indexed_dataset_infos_path(project_dir)): + update_chunk_counts(config, indexed_dataset_infos) + save_indexed_dataset_infos(project_dir, indexed_dataset_infos) + indexed_dataset_infos = get_indexed_dataset_infos(project_dir) + + # Builded merged dbs. 
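# --- Illustrative aside (not part of the patch): a minimal reader for the merged chunk
# database written by merge_dbs() above. The path is hypothetical; the dataset names and
# the (n_chunks, 5) uint32 layout (dataset_idx, doc_id, token_start_idx, token_end_idx,
# bert_chunk_length) mirror what merge_dbs() creates.
import h5py

def summarize_merged_db(db_path: str) -> None:
    with h5py.File(db_path, "r") as f:
        n_alloc = f["chunks"].shape[0]        # rows allocated up front
        n_written = int(f["n_written"][0])    # rows actually filled in
        print(f"{db_path}: {n_written}/{n_alloc} chunks written")

# e.g. summarize_merged_db("<retro_project_dir>/db/merged/train.hdf5")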
+ build_merged_dbs(project_dir, indexed_dataset_infos) diff --git a/megatron/core/datasets/retro/db/dataset.py b/megatron/core/datasets/retro/db/dataset.py new file mode 100644 index 0000000000..f9053622ab --- /dev/null +++ b/megatron/core/datasets/retro/db/dataset.py @@ -0,0 +1,105 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""A DBDataset is for iterating the chunks of the chunk database. + +This dataset is used for both training a vector index, and adding vectors to a +trained index. +""" + +from typing import List + +import numpy as np +import torch +from tqdm import tqdm + +from megatron.core.datasets.indexed_dataset import IndexedDataset + + +class DBDataset(torch.utils.data.Dataset): + """Dataset for iterating chunks. + + Args: + db_path (str): Path of HDF5-format chunk database. + indexed_datasets (List[IndexedDataset]): Indexed datasets used to build database. + chunks (np.ndarray): Array of chunk indexes, for indexing into indexed datasets. Format [dataset_idx, doc_id, start_idx, end_idx, bert_length]. + chunk_length (int): Max GPT chunk length (e.g., 64). + eod_token_id (int): EOD token ID. + """ + + def __init__( + self, + db_path: str, + indexed_datasets: List[IndexedDataset], + chunks: np.ndarray, + chunk_length: int, + eod_token_id: int, + ): + + assert chunks.shape[1] == 5, ( + "expected 5 columns (dataset_idx, " + "doc_idx, token_start_idx, token_end_idx, bert_chunk_length); " + "found %d columns." % chunks.shape[1] + ) + + self.db_path = db_path + self.indexed_datasets = indexed_datasets + self.chunks = chunks + self.doc_chunk_map = None + + self.max_chunk_length = chunk_length + self.eod_token_id = eod_token_id + + def __len__(self) -> int: + """Length of DB dataset. + + Returns: + Number of chunks contained in the dataset. + """ + return self.chunks.shape[0] + + def __getitem__(self, chunk_id: int) -> dict: + """DB dataset sample. + + Args: + chunk_id (int): Index of chunk within dataset. + + Returns: + A dict containing: + - 'doc_id': Document index within indexed dataset. + - 'text': GPT token IDs. + """ + + # Chunk start/end indexes. + indexed_dataset_id, doc_id, token_start_idx, token_end_idx, _ = [ + value.item() for value in self.chunks[chunk_id] + ] + chunk_length = token_end_idx - token_start_idx + indexed_dataset = self.indexed_datasets[indexed_dataset_id] + + # Chunk token ids. + token_ids = indexed_dataset.get(doc_id, offset=token_start_idx, length=chunk_length) + + # Extend chunks to max_chunk_length by padding with EOD tokens. + if chunk_length != self.max_chunk_length: + assert chunk_length < self.max_chunk_length, "invalid chunk len." + token_ids = token_ids.tolist() + token_ids += [self.eod_token_id] * (self.max_chunk_length - chunk_length) + + return {"doc_id": doc_id, "text": np.array(token_ids, dtype=np.int64)} + + def load_doc_tuples(self) -> None: + """Load the dataset & document ids. + + Load the dataset id & document id of each chunk in the database, to + be used for causality filtering during querying. 
+ """ + self.doc_tuples = np.zeros(shape=(len(self), 2), dtype="uint32") + block_size = int(1e6) + for start_idx in tqdm( + range(0, len(self), block_size), + "load doc tuples", + miniters=(len(self) // block_size) // 10, + disable=torch.distributed.get_rank() != 0, + ): + end_idx = min(len(self), start_idx + block_size) + self.doc_tuples[start_idx:end_idx] = self.chunks[start_idx:end_idx, :2] diff --git a/megatron/core/datasets/retro/db/utils.py b/megatron/core/datasets/retro/db/utils.py new file mode 100644 index 0000000000..e8578a09d5 --- /dev/null +++ b/megatron/core/datasets/retro/db/utils.py @@ -0,0 +1,367 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Utilities for building a chunk database.""" + +import glob +import json +import os +from typing import Dict, List, Optional + +import numpy as np + +from megatron.core.datasets.indexed_dataset import IndexedDataset +from megatron.core.datasets.retro.config import RetroPreprocessingConfig +from megatron.core.datasets.retro.external_libs import h5py +from megatron.core.models.retro.utils import get_gpt_data_dir + +from .dataset import DBDataset + + +def get_db_dir(project_dir: str) -> str: + """Sub-directory for DB data. + + Args: + project_dir (str): Path to Retro project dir. + + Returns: + Path of the DB sub-directory within the project. + """ + return os.path.join(project_dir, "db") + + +def init_indexed_dataset_infos(config: RetroPreprocessingConfig) -> List[Dict]: + """Gather meta-info about each indexed dataset. + + The returned info array allows for easy access to the configuration, and + helps remove ambiguity. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + + Returns: + List of processing metadata for each dataset, including: + - ratio: Data split weight. + - prefix: Relative path to dataset under DB sub-directory. + """ + + data_dir = get_gpt_data_dir(config.retro_project_dir) + data_blend: List[str] = config.retro_gpt_data_path + assert len(data_blend) % 2 == 0, "currently, only blended dataset is supported." + + # Dataset infos. + infos = [] + for i in range(0, len(data_blend), 2): + ratio = float(data_blend[i]) + prefix = data_blend[i + 1] + path = os.path.join(data_dir, prefix + ".bin") + assert os.path.exists(path), "couldn't find '%s'." % path + infos.append({"ratio": ratio, "prefix": prefix}) + + # Load indexed datasets. + load_indexed_datasets(config.retro_project_dir, infos) + + return infos + + +def get_indexed_dataset_infos_path(project_dir: str) -> str: + """Path to indexed dataset meta-infos. + + Args: + project_dir (str): Path to Retro project dir. + + Returns: + Path to the `indexed_dataset_infos.json` file. + """ + return os.path.join(get_db_dir(project_dir), "indexed_dataset_infos.json") + + +def save_indexed_dataset_infos(project_dir: str, indexed_dataset_infos: List[Dict]) -> None: + """Save dataset order & meta-info. + + Args: + project_dir (str): Path to Retro project dir. + indexed_dataset_infos (List[Dict]): List of metadata for each dataset, with each entry containing: + + - ratio: Data split weight. + - prefix: Relative path to dataset under DB sub-directory. + - n_docs: Number of documents. + - n_docs_train: Number of documents used for pretraining. + - n_chunks: Number of valid chunks. + - n_chunks_train: Number of valid chunks used for pretraining. + - n_chunks_invalid: Number of invalid chunks. + - n_chunks_sampled: Number of valid chunks used for vector index training. + """ + + # Remove 'dataset' field. 
+ clean_infos = [] + for info in indexed_dataset_infos: + info = dict(info) + del info["dataset"] + clean_infos.append(info) + + # Save. + with open(get_indexed_dataset_infos_path(project_dir), "w") as f: + json.dump(clean_infos, f, indent=4) + + +def load_indexed_datasets(project_dir: str, indexed_dataset_infos: List[Dict]) -> None: + """Loaded indexed datasets into memory-mapped datasets. + + Args: + project_dir (str): Path to Retro project dir. + indexed_dataset_infos (List[Dict]): List of metadata for each dataset (see `save_indexed_dataset_infos()` for more details. + """ + data_dir = get_gpt_data_dir(project_dir) + for info in indexed_dataset_infos: + info["dataset"] = IndexedDataset(os.path.join(data_dir, info["prefix"]), mmap=True) + + +def get_indexed_dataset_infos(project_dir: str) -> List[Dict]: + """Load indexed dataset meta-infos. + + Args: + project_dir (str): Path to Retro project dir. + + Returns: + List of metadata for each dataset (see `save_indexed_dataset_infos()` for more details. + """ + + # Load json. + path = get_indexed_dataset_infos_path(project_dir) + with open(path) as f: + infos = json.load(f) + + # Load indexed datasets. + load_indexed_datasets(project_dir, infos) + + return infos + + +def get_individual_db_dir(project_dir: str, prefix: str) -> str: + """Individual DB's directory. + + Args: + project_dir (str): Path to Retro project dir. + prefix (str): Unique relative path to dataset within project dir. + + Returns: + Path to the given datasets's chunk database. + """ + return os.path.join(get_db_dir(project_dir), "individual", prefix) + + +def get_individual_db_paths(project_dir: str, prefix: str) -> List[str]: + """Get paths of all database blocks of an individual dataset. + + Args: + project_dir (str): Path to Retro project dir. + prefix (str): Unique relative path to dataset within project dir. + + Returns: + Paths to each HDF5 chunk database files that comprises this datasets full chunk database. + """ + return sorted(glob.glob(get_individual_db_dir(project_dir, prefix) + "/*hdf5")) + + +def get_individual_chunk_db(project_dir: str, ds_id: int, ds_info: dict) -> np.ndarray: + """Load individual dataset's chunk DB. + + Args: + project_dir (str): Path to Retro project dir. + ds_id (int): Index of dataset within blended dataset. + ds_info (dict): Preprocessing metadata for dataset (see `save_indexed_dataset_infos()` for more detail). + + Returns: + Array of chunk start/end indexes for this dataset, where the chunk indexes can be used for indexing into the corresponding indexed dataset. + """ + paths = get_individual_db_paths(project_dir, ds_info["prefix"]) + # *Note*: convert to dataset, rather than copying to memory. + db = np.zeros((ds_info["n_chunks"], 5), dtype="uint32") + db[:, 0] = ds_id + start_idx = 0 + for path in paths: + f = h5py.File(path, "r") + n_chunks_current = f["chunks_valid"].shape[0] + db[start_idx : (start_idx + n_chunks_current), 1:] = f["chunks_valid"] + start_idx += n_chunks_current + f.close() + + assert start_idx == ds_info["n_chunks"] + + return db + + +def get_individual_doc_offsets(project_dir: str, ds_id: int, ds_info: dict) -> np.ndarray: + """Load individual dataset's document offsets. + + Args: + project_dir (str): Path to Retro project dir. + ds_id (int): Index of dataset within blended dataset. + ds_info (dict): Preprocessing metadata for dataset (see `save_indexed_dataset_infos()` for more detail). + + Returns: + Array of document offsets by chunk index for this dataset. 
+ """ + paths = get_individual_db_paths(project_dir, ds_info["prefix"]) + # *Note*: convert to dataset, rather than copying to memory. + doc_offsets = np.zeros((ds_info["n_docs"], 3), dtype="uint64") + doc_offsets[:, 0] = ds_id + start_idx = 0 + start_offset = 0 + for path in paths: + with h5py.File(path) as f: + current_doc_offsets = np.copy(f["doc_offsets"]) + current_doc_offsets[:, 1] += start_offset + current_ndocs = current_doc_offsets.shape[0] + doc_offsets[start_idx : (start_idx + current_ndocs), 1:] = current_doc_offsets + start_idx += current_ndocs + start_offset = current_doc_offsets[-1, 1].item() + + return doc_offsets + + +def get_merged_db_path_map(project_dir: str) -> dict: + """Paths to merged datasets. + + Args: + project_dir (str): Path to Retro project dir. + + Returns: + A dict of chunk databases, one for each of: + - sampled: Chunks used for training the vector index. + - train: Chunks used for pretraining 'train' dataset. + - valid: Chunks used for pretraining 'valid' dataset. + """ + base_dir = get_db_dir(project_dir) + return { + "sampled": os.path.join(base_dir, "merged", "sampled.hdf5"), + "train": os.path.join(base_dir, "merged", "train.hdf5"), + "valid": os.path.join(base_dir, "merged", "valid.hdf5"), + } + + +def get_merged_dataset( + project_dir: str, + chunk_length: int, + eod_token_id: int, + db_type: str, + indexed_dataset_infos: Optional[List[Dict]] = None, +) -> DBDataset: + """Get merged dataset. + + Args: + project_dir (str): Path to Retro project dir. + chunk_length (int): GPT chunk length (e.g., 64). + eod_token_id (int): EOD token ID. + db_type (str): DB type (e.g., 'sampled', 'train', or 'valid'). + indexed_dataset_infos (Optional[List[Dict]]): Optionally, pre-loaded list of dataset metadata (see `save_indexed_dataset_infos()` for more detail). If not provided, the indexed dataset infos will be loaded from disk. + + Returns: + A DBDataset, which is a dataset that wraps the HDF5 chunk index array. + """ + + if not indexed_dataset_infos: + indexed_dataset_infos = get_indexed_dataset_infos(project_dir) + + # Load chunks. + db_path = get_merged_db_path_map(project_dir)[db_type] + f = h5py.File(db_path, "r") + chunks = f["chunks"] + + # DB dataset. + indexed_datasets = [info["dataset"] for info in indexed_dataset_infos] + dataset = DBDataset( + db_path=db_path, + indexed_datasets=indexed_datasets, + chunks=chunks, + chunk_length=chunk_length, + eod_token_id=eod_token_id, + ) + + return dataset + + +def get_merged_sampled_dataset( + project_dir: str, + chunk_length: int, + eod_token_id: int, + indexed_dataset_infos: Optional[List[Dict]] = None, +) -> DBDataset: + """Get sampled dataset (for training the vector index). + + Args: + project_dir (str): Path to Retro project dir. + chunk_length (int): GPT chunk length (e.g., 64). + eod_token_id (int): EOD token ID. + indexed_dataset_infos (Optional[List[Dict]]): Optionally, pre-loaded list of dataset metadata (see `save_indexed_dataset_infos()` for more detail). If not provided, the indexed dataset infos will be loaded from disk. + + Returns: + A DBDataset, which is a dataset that wraps the HDF5 chunk index array. + """ + return get_merged_dataset( + project_dir, chunk_length, eod_token_id, "sampled", indexed_dataset_infos + ) + + +def get_merged_train_dataset( + project_dir: str, + chunk_length: int, + eod_token_id: int, + indexed_dataset_infos: Optional[List[Dict]] = None, +) -> DBDataset: + """Get training dataset (for adding to the vector index). + + Args: + project_dir (str): Path to Retro project dir. 
+ chunk_length (int): GPT chunk length (e.g., 64). + eod_token_id (int): EOD token ID. + indexed_dataset_infos (Optional[List[Dict]]): Optionally, pre-loaded list of dataset metadata (see `save_indexed_dataset_infos()` for more detail). If not provided, the indexed dataset infos will be loaded from disk. + + Returns: + A DBDataset, which is a dataset that wraps the HDF5 chunk index array. + """ + return get_merged_dataset( + project_dir, chunk_length, eod_token_id, "train", indexed_dataset_infos + ) + + +def get_merged_valid_dataset( + project_dir: str, + chunk_length: int, + eod_token_id: int, + indexed_dataset_infos: Optional[List[Dict]] = None, +) -> DBDataset: + """Get validation dataset (for testing the vector index). + + Args: + project_dir (str): Path to Retro project dir. + chunk_length (int): GPT chunk length (e.g., 64). + eod_token_id (int): EOD token ID. + indexed_dataset_infos (Optional[List[Dict]]): Optionally, pre-loaded list of dataset metadata (see `save_indexed_dataset_infos()` for more detail). If not provided, the indexed dataset infos will be loaded from disk. + + Returns: + A DBDataset, which is a dataset that wraps the HDF5 chunk index array. + """ + return get_merged_dataset( + project_dir, chunk_length, eod_token_id, "valid", indexed_dataset_infos + ) + + +def get_merged_datasets(project_dir: str, chunk_length: int, eod_token_id: int) -> dict: + """Get all merged datasets. + + Args: + project_dir (str): Path to Retro project dir. + chunk_length (int): GPT chunk length (e.g., 64). + eod_token_id (int): EOD token ID. + + Returns: + A dict mapping DB type ('sampled', 'train', or 'valid') to the corresponding DBDataset, which is a dataset that wraps the HDF5 chunk index array. + """ + fns = { + "sampled": get_merged_sampled_dataset, + "train": get_merged_train_dataset, + "valid": get_merged_valid_dataset, + } + datasets = {key: fn(project_dir, chunk_length, eod_token_id) for key, fn in fns.items()} + return datasets diff --git a/megatron/core/datasets/retro/external_libs.py b/megatron/core/datasets/retro/external_libs.py new file mode 100644 index 0000000000..c057eba25c --- /dev/null +++ b/megatron/core/datasets/retro/external_libs.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Required external libraries for Retro preprocessing.""" + +import importlib + +required_libs = ["faiss", "h5py", "transformers"] # for huggingface bert + +for lib in required_libs: + try: + globals()[lib] = importlib.import_module(lib) + except ImportError as e: + raise Exception( + f"Missing one or more packages required for Retro preprocessing: {required_libs}. Tried importing '{lib}'." + ) diff --git a/megatron/core/datasets/retro/index/__init__.py b/megatron/core/datasets/retro/index/__init__.py new file mode 100644 index 0000000000..d069f55f22 --- /dev/null +++ b/megatron/core/datasets/retro/index/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +Exports: + + - train_index: Train an index on representative vectors. + - add_to_index: Add vectors to a trained index. + - build_index: Wrapper function that calls above two functions. +""" + +from .build import add_to_index, build_index, train_index diff --git a/megatron/core/datasets/retro/index/build.py b/megatron/core/datasets/retro/index/build.py new file mode 100644 index 0000000000..1f310d89c3 --- /dev/null +++ b/megatron/core/datasets/retro/index/build.py @@ -0,0 +1,313 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
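# --- Illustrative aside (not part of the patch): the two-phase Faiss pattern ("train" on a
# representative sample, then "add" the full set) that this module orchestrates at much
# larger scale. Dimensions, sizes, and the index factory string are toy values.
import faiss
import numpy as np

d = 64
vectors = np.random.rand(100_000, d).astype("f4")   # vectors to make searchable
index = faiss.index_factory(d, "IVF256,Flat")       # toy index spec
index.train(vectors[:20_000])                       # phase 1: train on a sample
index.add(vectors)                                  # phase 2: add everything
_, neighbor_ids = index.search(vectors[:5], 4)      # query: 4 nearest neighbors per vector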
+ +"""Construct an index. + +Constructing an index generally happens in two phases: + + - index.train(): Train an index on a representative set of vectors. + - index.add(): Add vectors to an index, to be available for retrieval. +""" + +import os +import shutil + +import numpy as np +import torch +from tqdm import tqdm + +from megatron.core.datasets.retro.config import RetroPreprocessingConfig +from megatron.core.datasets.retro.db.utils import ( + get_merged_sampled_dataset, + get_merged_train_dataset, +) +from megatron.core.datasets.retro.external_libs import h5py +from megatron.core.datasets.retro.utils import GPTToTextDataset + +from .factory import IndexFactory +from .utils import ( + get_training_data_block_dir, + get_training_data_block_paths, + get_training_data_merged_path, + get_training_data_root_dir, +) + +################################################## +# Train index. +################################################## + + +def get_empty_index_path(config: RetroPreprocessingConfig) -> str: + """Path of empty index. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + + Returns: + Path to the empty (trained, but without added samples) vector index. + """ + index = IndexFactory.get_index(config.retro_index_type) + empty_index_path = index.get_empty_index_path(config) + return empty_index_path + + +def get_block_nload(block_path: str, load_fraction: float) -> int: + """Compute number of blocks to load. + + This is computed by multiplying the total number of available blocks with the + fraction of blocks to load. + + Args: + block_path (str): Path to HDF5 file containing block of data. File must contain key 'data'. + load_fraction (float): Fraction (0 < load_fraction <= 1) of block samples to load. + + Returns: + Number of block samples to load. + """ + with h5py.File(block_path) as fi: + return int(load_fraction * fi["data"].shape[0]) + + +def merge_embedding_blocks(config: RetroPreprocessingConfig) -> None: + """Merge individual embedding blocks into a single binary mmap file. + + The embeddings are initially stored in block-sized (e.g., ~100k embeddings per + block) HDF5 files. These individual block files must be merged into a single + file before training, to be based as a numpy mmap array to the index. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + """ + + if torch.distributed.get_rank() != 0: + return + + # Get block, merged paths. + load_fraction = config.retro_index_train_load_fraction + block_paths = get_training_data_block_paths(config) + bin_path = get_training_data_merged_path(config) + + # Skip, if already built. + if os.path.exists(bin_path): + return + + # Merge blocks. + with open(bin_path, "wb") as fo: + byte_offset = 0 + for block_idx, block_path in enumerate( + tqdm( + block_paths, + "merge train embeddings", + miniters=len(block_paths) // 10, + disable=torch.distributed.get_rank() != 0, + ) + ): + with h5py.File(block_path) as fi: + + nload = get_block_nload(block_path, load_fraction) + block = np.array(fi["data"][:nload], copy=False) + + fo.write(block.tobytes()) + + byte_offset += block.size * block.itemsize + fo.seek(byte_offset) + + +def get_text_dataset_for_training(config: RetroPreprocessingConfig) -> GPTToTextDataset: + """Convert GPT token chunk dataset to a text dataset for passing to the + embedder. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + + Returns: + The text dataset consisting of tokens converted from sampled chunk database. 
+ """ + gpt_dataset = get_merged_sampled_dataset( + project_dir=config.retro_project_dir, + chunk_length=config.retro_gpt_chunk_length, + eod_token_id=config.retro_tokenizers.gpt.eod, + ) + text_dataset = GPTToTextDataset(gpt_dataset, config.retro_tokenizers.gpt) + return text_dataset + + +def embed_training_chunks(config: RetroPreprocessingConfig) -> None: + """Embed DB chunks. + + Store chunks in blocks on disk. These blocks will later be merged into + a single dataset for training the index. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + """ + + merged_train_data_path = get_training_data_merged_path(config) + if os.path.exists(merged_train_data_path): + return + + # Get training text dataset. + text_dataset = get_text_dataset_for_training(config) + + # Embed dataset. + embedder = config.retro_bert_embedders.disk + embedder.embed_text_dataset("index", get_training_data_block_dir(config), text_dataset) + + # Merge embeddings. + merge_embedding_blocks(config) + + +def train_on_embeddings(config: RetroPreprocessingConfig) -> None: + """Train index on embedded DB chunks. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + """ + index = IndexFactory.get_index(config.retro_index_type) + index.train(config) + + +def remove_embeddings(config: RetroPreprocessingConfig) -> None: + """Remove embeddings after training. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + """ + torch.distributed.barrier() + if torch.distributed.get_rank() != 0: + return + empty_index_path = get_empty_index_path(config) + assert os.path.isfile(empty_index_path) + shutil.rmtree(get_training_data_root_dir(config), ignore_errors=True) + + +def _train_index(config: RetroPreprocessingConfig) -> None: + """Train index on DB chunks. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + """ + + # Check if trained index already exists. + if not os.path.isfile(get_empty_index_path(config)): + + # Embed training chunks. + embed_training_chunks(config) + + # Train index on embeddings. + train_on_embeddings(config) + + # Wait for (single-process) training to complete. + torch.distributed.barrier() + + # Remove embeddings. + if config.retro_index_delete_training_embeddings: + remove_embeddings(config) + + +def train_index(config: RetroPreprocessingConfig) -> None: + """Entry point for training the index. + + We select whether to train a new index, or validate an existing index. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + """ + + # Train new index. + if config.retro_task_validate is None: + _train_index(config) + + # Validate existing trained index. + else: + from .validate import validate_training_embeddings + + validate_training_embeddings(config) + + +################################################## +# Add to index. +################################################## + + +def get_text_dataset_for_adding(config: RetroPreprocessingConfig) -> GPTToTextDataset: + """Convert GPT token chunk dataset to a text dataset for passing to the + embedder. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + + Returns: + The text dataset that consists of tokens converted from the 'train' chunk database. These are the chunks used for retrieval by the pretraining 'train' dataset. 
+ """ + gpt_dataset = get_merged_train_dataset( + project_dir=config.retro_project_dir, + chunk_length=config.retro_gpt_chunk_length, + eod_token_id=config.retro_tokenizers.gpt.eod, + ) + text_dataset = GPTToTextDataset(gpt_dataset, config.retro_tokenizers.gpt) + return text_dataset + + +def _add_to_index(config: RetroPreprocessingConfig) -> str: + """Add DB chunks to index. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + + Returns: + Path to the populated index. + """ + + # Get index. + index = IndexFactory.get_index(config.retro_index_type) + + # Get text dataset. + text_dataset = get_text_dataset_for_adding(config) + + # Add to index. + output_index_path = index.add(config, text_dataset) + + return output_index_path + + +def add_to_index(config: RetroPreprocessingConfig) -> None: + """Entry point for adding to the index. + + We select whether to add to a new index, or validate an existing index. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + """ + + # Add to new index. + if config.retro_task_validate is None: + _add_to_index(config) + + # Validate existing encodings. + else: + from .validate import validate_added_encodings + + validate_added_encodings(config) + + +################################################## +# Build index (train + add). +################################################## + + +def build_index(config: RetroPreprocessingConfig) -> None: + """Build index. + + Building index involves sequentially running stages above: + - Train index (on sampled training chunks). + - Add to index (on all training chunks). + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + """ + + # Train index. + train_index(config) + + # Add to index. + add_to_index(config) diff --git a/megatron/core/datasets/retro/index/factory.py b/megatron/core/datasets/retro/index/factory.py new file mode 100644 index 0000000000..f88084ddb1 --- /dev/null +++ b/megatron/core/datasets/retro/index/factory.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""The IndexFactory constructs an index from an index type string.""" + +from megatron.core.datasets.retro.index.index import Index + +from .indexes import FaissBaseIndex, FaissParallelAddIndex + + +class IndexFactory: + """Get index. + + Index type generally read from argument '--retro-index-ty'. + """ + + @classmethod + def get_index_class(cls, index_type: str) -> type: + """Get an index class, given a type string. + + Args: + index_type (str): One of 'faiss-base' (naive Faiss index wrapper) or 'faiss-par-add' (Faiss index wrapper with near embarrassingly parallel index.add(). + + Returns: + An `Index` sub-type corresponding to the `index_type`. + """ + return {"faiss-base": FaissBaseIndex, "faiss-par-add": FaissParallelAddIndex}[index_type] + + @classmethod + def get_index(cls, index_type: str) -> Index: + """Construct an index from an index type string. + + Args: + index_type (str): One of 'faiss-base' (naive Faiss index wrapper) or 'faiss-par-add' (Faiss index wrapper with near embarrassingly parallel index.add(). + + Returns: + An `Index` instance corresponding to the `index_type`. 
+ """ + index_class = cls.get_index_class(index_type) + index = index_class() + return index diff --git a/megatron/core/datasets/retro/index/index.py b/megatron/core/datasets/retro/index/index.py new file mode 100644 index 0000000000..c6bd13fbee --- /dev/null +++ b/megatron/core/datasets/retro/index/index.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Base class for all vector indexes. + +A vector index is a type of retrieval database that is queried using vectors, +and returns vectors that are 'similar' (e.g., by cosine distance) to the query +vector. The construction and usage of an index generally has the following +pattern: + + - Train the index on representative vectors. + - Add vectors to the index (i.e., vectors available for retrieval) + - Query index with new vector, to retrieve similar vector indexes. +""" + +import abc +import os +from typing import List, Tuple + +import numpy as np +import torch + +from megatron.core.datasets.retro.config import Embedder, RetroPreprocessingConfig +from megatron.core.datasets.retro.external_libs import faiss +from megatron.core.datasets.retro.utils import GPTToTextDataset + +from .utils import get_index_dir + + +class Index(abc.ABC): + """Abstract base class for indexes. + + *Note* : While currently only Faiss-based classes are implemented, in the + future, this class will be extended with other types of indexes that have + different performance-accuracy trade-offs. + + The primary methods to override are: + - train() : Train index on the sampled training chunks. + - add() : Add all training chunks to index. + """ + + @classmethod + def make_object_verbose(cls, index: faiss.Index, verbose: bool) -> None: + """Make index object verbose. + + Args: + index (faiss.Index): Faiss object to set verbose. + verbose (bool): Sets whether index should log status updates during training and adding. + """ + assert isinstance(verbose, bool) + faiss.ParameterSpace().set_index_parameter(index, "verbose", verbose) + + def get_empty_index_path(self, config: RetroPreprocessingConfig) -> str: + """Get file path to empty index (i.e., trained, but unpopulated). + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + + Returns: + File path to empty index (i.e., this index has had index.train() called, but not yet index.add()). + """ + return os.path.join( + get_index_dir(config), "empty_%.3f.faissindex" % config.retro_index_train_load_fraction + ) + + def get_empty_index(self, config: RetroPreprocessingConfig) -> faiss.Index: + """Get empty index (i.e., trained, but unpopulated). + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + + Returns: + Empty Faiss index, loaded from storage. + """ + return faiss.read_index(self.get_empty_index_path(config)) + + def get_added_index_path(self, config: RetroPreprocessingConfig) -> str: + """Get file path to index that has been populated with vectors. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + + Returns: + File path to added index (i.e., this index has had both index.train() and index.add() called). + """ + return os.path.join( + get_index_dir(config), + "added_%.3f_%.3f.faissindex" + % (config.retro_index_train_load_fraction, config.retro_index_add_load_fraction), + ) + + def get_added_index(self, config: RetroPreprocessingConfig) -> faiss.Index: + """Get index that has been populated with vectors. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. 
+ + Returns: + 'Added' (i.e., populated) Faiss index, loaded from storage. + """ + return faiss.read_index(self.get_added_index_path(config)) + + @abc.abstractmethod + def train(self, config: RetroPreprocessingConfig) -> None: + """Train index on a representative set of vectors. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + """ + + @abc.abstractmethod + def add(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> None: + """Add vectors to index. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + text_dataset (GPTToTextDataset): Text dataset that will be embedded and added to the index. + """ + + def embed_text_dataset_block( + self, embedder: Embedder, text_dataset: GPTToTextDataset, _range: Tuple[int, int] + ) -> np.ndarray: + """Embed a range of a text dataset. + + Args: + embedder (Embedder): Embedder used for embedding a text dataset. + text_dataset (GPTToTextDataset): Text dataset that will be embedded. + _range (Tuple[int, int]): Start/end sample indices within text dataset used for embedding. + + Returns: + An array of embeddings, with shape (len(text_dataset), dimension(embedder)). + """ + sub_dataset = torch.utils.data.Subset(text_dataset, range(*_range)) + return embedder.embed_text_dataset(sub_dataset) diff --git a/megatron/core/datasets/retro/index/indexes/__init__.py b/megatron/core/datasets/retro/index/indexes/__init__.py new file mode 100644 index 0000000000..c445909fea --- /dev/null +++ b/megatron/core/datasets/retro/index/indexes/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +Exports: +- FaissBaseIndex: Unoptimized Faiss index wrapper +- FaissParallelAddIndex: Optimized index.add() for Faiss index. +""" + +from .faiss_base import FaissBaseIndex +from .faiss_par_add import FaissParallelAddIndex diff --git a/megatron/core/datasets/retro/index/indexes/faiss_base.py b/megatron/core/datasets/retro/index/indexes/faiss_base.py new file mode 100644 index 0000000000..c1daf3f533 --- /dev/null +++ b/megatron/core/datasets/retro/index/indexes/faiss_base.py @@ -0,0 +1,150 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +This class implements a simple, un-optimized wrapper around a Faiss index, that +implements the Index interface (see ..index.py). While this class is +instantiable, it is meant to be extended with optimizations in classes that +inherit from this class (see FaissParAddIndex, for an example). +""" + +import os + +import numpy as np +import torch +from tqdm import tqdm + +from megatron.core.datasets.retro.config import RetroPreprocessingConfig +from megatron.core.datasets.retro.external_libs import faiss +from megatron.core.datasets.retro.index.index import Index +from megatron.core.datasets.retro.index.utils import ( + get_training_data_merged_path, + num_samples_to_block_ranges, +) +from megatron.core.datasets.retro.utils import GPTToTextDataset, log_retro_rank_0 + + +class FaissBaseIndex(Index): + """Base class for Faiss-base indexes. + + This class wraps a Faiss index, and adds additional functionality for training + and adding codes. This base class performs a naive sequential code adding, + while the optimized FaissParallelAddIndex class performs a parallel + index.add(). + """ + + def _train(self, config: RetroPreprocessingConfig) -> None: + """Train index (rank 0's method). + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. 
+ """ + + assert torch.distributed.get_rank() == 0 + + # Set num threads (torch.distributed reset it to 1). + faiss.omp_set_num_threads(64) + + empty_index_path = self.get_empty_index_path(config) + + # Index already exists? -> return. + if os.path.isfile(empty_index_path): + return + + # Load data. + merged_path = get_training_data_merged_path(config) + inp = np.memmap(merged_path, dtype="f4", mode="r").reshape((-1, config.hidden_size)) + + # Init index. + index = faiss.index_factory(config.hidden_size, config.retro_index_str) + + # Move to GPU. + log_retro_rank_0("> move faiss index to gpu.") + index_ivf = faiss.extract_index_ivf(index) + clustering_index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(index_ivf.d)) + index_ivf.clustering_index = clustering_index + log_retro_rank_0("> finished moving to gpu.") + self.make_object_verbose(index, True) + self.make_object_verbose(index_ivf, True) + self.make_object_verbose(index_ivf.quantizer, True) + self.make_object_verbose(index_ivf.clustering_index, True) + + # Train index. + index.train(inp) + + # Save index. + faiss.write_index(index, empty_index_path) + + def train(self, config: RetroPreprocessingConfig) -> None: + """Train index. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + """ + + # Single process only. + if torch.distributed.get_rank() == 0: + self._train(config) + + torch.distributed.barrier() + + def _add(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> None: + """Add to index (rank 0's method). + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + text_dataset (GPTToTextDataset): Text dataset that will be embedded and added to the index. + """ + + assert torch.distributed.get_rank() == 0 + + dataset_sample_ranges = num_samples_to_block_ranges(len(text_dataset)) + + # Set num threads (torch.distributed reset it to 1). + faiss.omp_set_num_threads(64) + + # Bert embedder. + embedder = config.bert_embedders.mem + + # Empty/added index paths. + empty_index_path = self.get_empty_index_path() + added_index_path = self.get_added_index_path() + + # Skip adding, if index exists. + if os.path.isfile(added_index_path): + return + + # Read trained index. + index = faiss.read_index(empty_index_path) + + # Iterate data blocks & add. + for sample_range in tqdm(dataset_sample_ranges, "faiss_base.add"): + + # Embed text. + embeds = self.embed_text_dataset_block(embedder, text_dataset, sample_range) + + # Add to index. + index.add(embeds) + + # Write index. + faiss.write_index(index, added_index_path) + + def add(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> str: + """Add to index. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + text_dataset (GPTToTextDataset): Text dataset that will be embedded and added to the index. + + Returns: + File path to the populated index. + """ + + # Single process only. + if torch.distributed.get_rank() == 0: + self._add(config, text_dataset) + + # Wait for rank 0. + torch.distributed.barrier() + + # Get output index path, for return. + return self.get_added_index_path(config) diff --git a/megatron/core/datasets/retro/index/indexes/faiss_par_add.py b/megatron/core/datasets/retro/index/indexes/faiss_par_add.py new file mode 100644 index 0000000000..e014217262 --- /dev/null +++ b/megatron/core/datasets/retro/index/indexes/faiss_par_add.py @@ -0,0 +1,208 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Multi-process & multi-node version of Faiss's index.add(). 
+ +This class inherits from FaissBaseIndex, and optimizes the 'add()' method by +making it multi-node and multi-process, with bit-wise equivalence to +FaissBaseIndex. This allows 'add()' to scale out to very large datasets, since +the vast majority of the computational effort is embarrassingly parallel. +""" + +import os +import shutil +from typing import Tuple + +import numpy as np +import psutil +import torch +from tqdm import tqdm + +from megatron.core.datasets.retro.config import Embedder, RetroPreprocessingConfig +from megatron.core.datasets.retro.external_libs import faiss, h5py +from megatron.core.datasets.retro.index.utils import get_added_code_paths, get_added_codes_dir +from megatron.core.datasets.retro.utils import ( + GPTToTextDataset, + get_blocks_by_rank, + log_retro_rank_0, + retro_makedir, +) + +from .faiss_base import FaissBaseIndex + + +class FaissParallelAddIndex(FaissBaseIndex): + """ + This class parallelizes both 1) encoding vectors, and 2) adding codes to the + index. This class is more performant than naive use of Faiss, because most + of the computational work is in encoding the vectors, which is an + embarassingly parallel operation. + """ + + def encode_block( + self, index: faiss.Index, embedder: Embedder, text_dataset: GPTToTextDataset, block: dict + ) -> Tuple[np.ndarray, np.ndarray]: + """Encode sub-dataset block, to be later added to index. + + Encode the data subset, generally in blocks of 1M vectors each. For + each block, the empty/trained index is loaded, codes are computed + via index.sa_encode(), and the resulting codes are saved to disk. + + Args: + index (faiss.Index): Faiss index object. + embedder (Embedder): Embedder used to embed text dataset. + text_dataset (GPTToTextDataset): Text dataset to be embedded and encoded. + block (dict): Range information specifying start/end indices within text dataset. + + Returns: + A tuple of (embeddings, encodings) for the given block subset of the text dataset. + """ + + # Embed block. + embeddings = self.embed_text_dataset_block(embedder, text_dataset, block["range"]) + + # Encode block. + log_retro_rank_0("encode.") + codes = index.sa_encode(embeddings) + + # Return embeddings for validation purposes. + return embeddings, codes + + def save_block(self, config: RetroPreprocessingConfig, block: dict, codes: np.ndarray) -> None: + """Save block of codes to disk. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + block (dict): Range information specifying the start/end indices within the encoded text dataset. Here, the 'path' item is used for writing the encodings to storage. + codes (np.ndarray): Block of encodings to be saved to storage. + """ + # Save neighbors. + log_retro_rank_0("save codes.") + retro_makedir(config, os.path.dirname(block["path"])) + with h5py.File(block["path"], "w") as f: + f.create_dataset("data", data=codes) + + def encode(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> None: + """Encode text dataset, to be later added to index. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + text_dataset (GPTToTextDataset): Text dataset to be encoded by the index. + """ + + codes_dir = get_added_codes_dir(config) + retro_makedir(config, codes_dir) + + # Index. + index = self.get_empty_index(config) + + # Bert embedder. + embedder = config.retro_bert_embedders.mem + + # Missing code blocks. + def validate(f: h5py.File) -> None: + """Validation method for validating loaded encodings. 
+ + Args: + f (h5py.File): File that contains encodings. + """ + assert len(f["data"].shape) == 2 + + blocks = get_blocks_by_rank( + codes_dir, len(text_dataset), config.retro_block_size, validate=validate + ) + + # Encode each block. + for block_index, block in enumerate(blocks.missing): + + if block is not None: + + # Progress. + log_retro_rank_0( + "encode block %d / %d ... %s." + % (block_index, len(blocks.missing), block["path"]) + ) + + # Encode and save. + _, codes = self.encode_block(index, embedder, text_dataset, block) + self.save_block(config, block, codes) + + # Synchronize progress across all ranks. (for easier observation) + log_retro_rank_0(" > waiting for other ranks to finish block.") + torch.distributed.barrier() + + def add_codes(self, config: RetroPreprocessingConfig) -> None: + """Read codes from disk, and add them to the index. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + """ + + if torch.distributed.get_rank() != 0: + return + + added_index_path = self.get_added_index_path(config) + if os.path.exists(added_index_path): + return + + # Index. + log_retro_rank_0("read empty index.") + index = self.get_empty_index(config) + index_ivf = faiss.extract_index_ivf(index) + + # Add codes. + log_retro_rank_0("add codes.") + code_paths = get_added_code_paths(config) + pbar = tqdm(code_paths) + for code_path in pbar: + pbar.set_description( + "add codes, mem %.3f gb, %.1f%%" + % (psutil.virtual_memory()[3] / 1024**3, psutil.virtual_memory()[2]) + ) + with h5py.File(code_path) as f: + + nload = int(config.retro_index_add_load_fraction * f["data"].shape[0]) + offset = int(os.path.basename(code_path).split("-")[0]) + xids = np.arange(offset, offset + nload) + codes = np.copy(f["data"][:nload]) + index_ivf.add_sa_codes(codes, xids) + + # Update index's ntotal. + index.ntotal = index_ivf.ntotal + + # Write index. + log_retro_rank_0("write added index.") + faiss.write_index(index, added_index_path) + + def remove_codes(self, config: RetroPreprocessingConfig) -> None: + """Remove added codes after adding to index. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + """ + if torch.distributed.get_rank() != 0: + return + assert os.path.isfile(self.get_added_index_path(config)) + + if config.retro_index_delete_added_codes: + raise Exception("remove?") + shutil.rmtree(get_added_codes_dir(config), ignore_errors=True) + + def add(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> None: + """Add vectors to index. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + text_dataset (GPTToTextDataset): Text dataset that will be embedded and added to the index. + """ + + # Encode chunks. + self.encode(config, text_dataset) + + # Add codes to index. + self.add_codes(config) + + # Wait for (single-process) adding to complete. + torch.distributed.barrier() + + # Remove codes. + self.remove_codes(config) diff --git a/megatron/core/datasets/retro/index/utils.py b/megatron/core/datasets/retro/index/utils.py new file mode 100644 index 0000000000..58229439ae --- /dev/null +++ b/megatron/core/datasets/retro/index/utils.py @@ -0,0 +1,126 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
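# --- Illustrative aside (not part of the patch): how add_codes() above derives explicit
# Faiss ids for each block of codes. Block files are assumed here to be named
# "<start>-<end>.hdf5" (e.g. "0000200000-0000300000.hdf5"), so the starting sample offset
# can be parsed from the file name; that naming scheme is an assumption for illustration.
import os
import numpy as np

code_path = "/path/to/index/add_codes/0000200000-0000300000.hdf5"  # hypothetical path
nload = 100_000                                   # rows kept after applying the load fraction
offset = int(os.path.basename(code_path).split("-")[0])            # -> 200000
xids = np.arange(offset, offset + nload)          # ids passed alongside the codes to the index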
+ +"""Utilities for building an index.""" + +import glob +import os +from typing import List, Tuple + +from megatron.core.datasets.retro.config import RetroPreprocessingConfig +from megatron.core.datasets.retro.utils import retro_makedir + + +def get_index_dir(config: RetroPreprocessingConfig) -> str: + """Create sub-directory for this index. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + + Returns: + Path to index sub-directory within Retro project. + """ + + # Directory path. + index_dir_path = os.path.join( + config.retro_project_dir, "index", config.retro_index_type, config.retro_index_str + ) + + # Make directory. + retro_makedir(config, index_dir_path) + + return index_dir_path + + +def num_samples_to_block_ranges( + config: RetroPreprocessingConfig, num_samples: int +) -> List[Tuple[int, int]]: + """Split a range (length num_samples) into sequence of block ranges + of size block_size. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + num_samples (int): Split `num_samples` into consecutive block ranges, where each block is size `config.retro_block_size`. + + Returns: + A list of tuples where each item is the (start, end) index for a given block. + """ + block_size = config.retro_block_size + start_idxs = list(range(0, num_samples, block_size)) + end_idxs = [min(num_samples, s + block_size) for s in start_idxs] + ranges = list(zip(start_idxs, end_idxs)) + return ranges + + +def get_training_data_root_dir(config: RetroPreprocessingConfig) -> str: + """Get root directory for embeddings (blocks and merged data). + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + + Returns: + Path to the training data directory, which contains both training embedding blocks and the final merged training embeddings. + """ + return os.path.join(config.retro_project_dir, "index", "train_emb") + + +def get_training_data_block_dir(config: RetroPreprocessingConfig) -> str: + """Get directory for of saved embedding blocks. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + + Returns: + Path to the directory containing the training embedding blocks, which will be later merged into a single embedding array. + """ + return os.path.join(get_training_data_root_dir(config), "blocks") + + +def get_training_data_block_paths(config: RetroPreprocessingConfig) -> List[str]: + """Get paths to saved embedding blocks. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + + Returns: + Paths of all training embedding blocks. + """ + return sorted(glob.glob(get_training_data_block_dir(config) + "/*.hdf5")) + + +def get_training_data_merged_path(config: RetroPreprocessingConfig) -> str: + """Get path to merged training embeddings. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + + Returns: + Path to the merged training embedding binary file. + """ + return os.path.join( + get_training_data_root_dir(config), + "train_%.3f.bin" % config.retro_index_train_load_fraction, + ) + + +def get_added_codes_dir(config: RetroPreprocessingConfig) -> str: + """Get directory of saved encodings. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + + Returns: + Path to the directory containing the vector encodings for adding to the index. + """ + return os.path.join(get_index_dir(config), "add_codes") + + +def get_added_code_paths(config: RetroPreprocessingConfig) -> List[str]: + """Get paths to all saved encodings. 
+ + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + + Returns: + Paths of all vector encoding blocks, for adding to the index. + """ + return sorted(glob.glob(get_added_codes_dir(config) + "/*.hdf5")) diff --git a/megatron/core/datasets/retro/index/validate.py b/megatron/core/datasets/retro/index/validate.py new file mode 100644 index 0000000000..57306707c4 --- /dev/null +++ b/megatron/core/datasets/retro/index/validate.py @@ -0,0 +1,191 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Validate an index's data. + +This module contains functionality for checking for bitwise equality across code +changes. The training and adding steps of index construction can be validated +separately. The following high-level checks are supported: + + - Training: Validate that saved training embeddings are bitwise equal with a + sample set of freshly computed embeddings. (*Note*: + `--no-retro-index-delete-training-embeddings` must be used.) + - Adding: Validate that the saved encodings are bitwise equal with a sample of + sample set of freshly computed encodings. (*Note*: + `--no-retro-index-delete-added-codes` must be used.) +""" + +import typing + +import numpy as np +import torch +from torch.utils.data import Subset + +from megatron.core.datasets.retro.config import RetroPreprocessingConfig +from megatron.core.datasets.retro.external_libs import h5py +from megatron.core.datasets.retro.utils import ( + GPTToTextDataset, + get_blocks_by_rank, + log_retro_rank_0, +) + +from .build import get_text_dataset_for_adding, get_text_dataset_for_training +from .factory import IndexFactory +from .utils import get_added_codes_dir, get_training_data_block_dir + +################################################## +# Validate trained index. +################################################## + + +def validate_training_embeddings(config: RetroPreprocessingConfig) -> None: + """Validate training embeddings. + + Steps: + - Randomly sample subset of text dataset blocks. + - Embed each block. + - Compare against saved embeddings. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + """ + + # Training text dataset. + text_dataset = get_text_dataset_for_training(config) + + # Sample existing blocks. + blocks = get_blocks_by_rank( + dirname=get_training_data_block_dir(config), + n_samples=len(text_dataset), + block_size=config.retro_block_size, + validate=None, + sample=config.retro_task_validate, + ) + + assert blocks.n_missing_world == 0 + + # Embed & validate blocks. + embedder = config.retro_bert_embedders.mem + for block_idx, block in enumerate(blocks.existing): + + # Missing block lists are extended with None to have equal-length + # lists. Skip the Nones. + if block is not None: + + # Progress. (*note*: move world progress to here.) + log_retro_rank_0( + "embed training block %d / %d ... %s." + % (block_idx, len(blocks.existing), block["path"]) + ) + + # Load existing block embeddings. + with h5py.File(block["path"]) as f: + existing_embeddings = np.copy(f["data"]) + + # Embed block. + sub_dataset = Subset(text_dataset, range(*block["range"])) + embeddings = embedder.embed_text_dataset(sub_dataset, "train") + + # Check equality. + log_retro_rank_0(" > validate.") + assert np.array_equal(existing_embeddings, embeddings) + + # Synchronize progress across all ranks. 
(for easier observation) + log_retro_rank_0(" > waiting for other ranks to finish block.") + torch.distributed.barrier() + + log_retro_rank_0(" > finished validating training embeddings.") + + +################################################## +# Validate filled index. +################################################## + + +def validate_added_encodings(config: RetroPreprocessingConfig) -> None: + """Validate added encodings. + + Steps: + - Randomly sample subset of text dataset blocks. + - Encode each block. + - Compare against saved encodings. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + """ + + # Index. + index = IndexFactory.get_index(config.retro_index_type) + inner_index = index.get_empty_index(config) + + # Text dataset. + text_dataset = get_text_dataset_for_adding(config) + + # Sample existing blocks. + def validate(f: h5py.File) -> None: + """Validation method for validating encoding blocks. + + Args: + f (h5py.File): File with block of encodings. + """ + assert len(f["data"].shape) == 2 + + blocks = get_blocks_by_rank( + dirname=get_added_codes_dir(config), + n_samples=len(text_dataset), + block_size=config.retro_block_size, + validate=validate, + sample=config.retro_task_validate, + ) + + assert blocks.n_missing_world == 0 + + # Encode and validate blocks. + embedder = config.retro_bert_embedders.mem + for block_idx, block in enumerate(blocks.existing): + + if block is not None: + + # Progress. + log_retro_rank_0( + "encode block %d / %d ... %s." % (block_idx, len(blocks.existing), block["path"]) + ) + + # Load existing codes. + with h5py.File(block["path"]) as f: + existing_codes = np.copy(f["data"]) + + # Encode block. + embeddings, codes = index.encode_block(inner_index, embedder, text_dataset, block) + + # Check equality. + log_retro_rank_0(" > validate.") + assert np.array_equal(existing_codes, codes) + + # Synchronize progress across all ranks. (for easier observation) + log_retro_rank_0(" > waiting for other ranks to finish block.") + torch.distributed.barrier() + + log_retro_rank_0(" > finished validating added encodings.") + + +################################################## +# Validate index (trained + filled). +################################################## + + +def validate_index(config: RetroPreprocessingConfig) -> None: + """Validate index. + + Validating index involves sequentially running stages above: + - Validate trained index. + - Validate filled index. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + """ + + # Validate training embeddings. + validate_training_embeddings(config) + + # Validate added codes. + validate_added_encodings(config) diff --git a/megatron/core/datasets/retro/query/__init__.py b/megatron/core/datasets/retro/query/__init__.py new file mode 100644 index 0000000000..ac9483373c --- /dev/null +++ b/megatron/core/datasets/retro/query/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/core/datasets/retro/query/gpt_chunk_dataset.py b/megatron/core/datasets/retro/query/gpt_chunk_dataset.py new file mode 100644 index 0000000000..6191a30a31 --- /dev/null +++ b/megatron/core/datasets/retro/query/gpt_chunk_dataset.py @@ -0,0 +1,109 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +A GPTChunkDataset is a wrapper around a regular GPTDataset, that sequentially +chunks the sample tokens into `retro_chunk_length` sized smaller samples. 
+ +For example, if the GPTDataset has 100 samples and a sequence length of 2048, and +retro_chunk_length is 64, then the GPTChunkDataset will contain 100*(2048/64) = +3200 samples, each with length 64. +""" + +import torch + +from megatron.core.datasets.gpt_dataset import GPTDataset +from megatron.core.datasets.retro.utils import get_num_chunks_per_sample + +from .utils import get_neighbor_dir + + +class GPTChunkDataset(torch.utils.data.Dataset): + """Pretraining chunk dataset wraps a standard GPT dataset. + + This dataset conceptually divides each sample (e.g., length 2048) + into chunks (e.g., length 64) and restructures them into a list of + chunks (e.g., length num_samples * num_chunks_per_sample). + + Args: + sample_dataset (GPTDataset): Original GPT dataset, with `sequence_length` size samples. + sample_length (int): Alias for `sequence_length`. + chunk_length (int): Retro chunk length (e.g., 64). + """ + + def __init__(self, sample_dataset: GPTDataset, sample_length: int, chunk_length: int): + + super().__init__() + + self.sample_dataset = sample_dataset + self.chunk_length = chunk_length + self.n_chunks_per_sample = get_num_chunks_per_sample(sample_length, chunk_length) + self.n_samples = len(sample_dataset) + self.n_chunks = self.n_samples * self.n_chunks_per_sample + + def __len__(self) -> int: + """Get dataset length. + + Returns: + Dataset length. + """ + return self.n_chunks + + def __getitem__(self, idx: int) -> dict: + """Get sample, including represented document IDs. + + Args: + idx (int): Sample index. + + Returns: + A sample, which contains both the chunk-length token sample ('text') along with all document_ids ('doc_ids') contained withing the full `sequence_length` sample. + """ + + # Convert global chunk index to global sample index & local chunk index. + sample_idx = idx // self.n_chunks_per_sample + chunk_idx = idx % self.n_chunks_per_sample + + # Extract sample data. + sample = self.sample_dataset[sample_idx] + sample_token_ids = sample["text"] + sample_doc_ids = sample["document_ids"] + + # Chunk start/end token idxs. + token_start_idx = chunk_idx * self.chunk_length + token_end_idx = token_start_idx + self.chunk_length + chunk_token_ids = sample_token_ids[token_start_idx:token_end_idx] + + # Sample. + return {"doc_ids": sample_doc_ids, "text": chunk_token_ids} + + +def build_gpt_chunk_datasets_from_gpt_datasets( + project_dir: str, gpt_datasets: dict, sample_length: int, chunk_length: int +) -> dict: + """Get train, valid, test GPT chunk datasets. + + Args: + project_dir (str): Retro project dir. + gpt_datasets (dict): Mapping of 'train', 'valid', and 'test' GPT datasets (original, unchunked datasets). + sample_length (int): Alias of `sequence_length`. + chunk_length (int): Retro chunk length (e.g., 64). + + Returns: + A ? + """ + + # GPT chunk datasets. 
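    # (Editorial note, not part of the original patch): each split key ('train', 'valid',
    # 'test') maps either to None (unused split) or to a dict of the form
    #     {"dataset": GPTChunkDataset, "neighbor_dir": str, "num_active_chunks": int},
    # where num_active_chunks = num_active_samples * (sample_length // chunk_length).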
+ chunk_datasets = { + key: ( + { + "dataset": GPTChunkDataset(sample_ds, sample_length, chunk_length), + "neighbor_dir": get_neighbor_dir(project_dir, key, sample_ds), + "num_active_chunks": num_active_samples + * get_num_chunks_per_sample(sample_length, chunk_length), + } + if sample_ds + else None + ) + for key, (sample_ds, num_active_samples) in gpt_datasets.items() + } + + return chunk_datasets diff --git a/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py b/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py new file mode 100644 index 0000000000..97a891fd14 --- /dev/null +++ b/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py @@ -0,0 +1,107 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""A MultiSplitGPTDataset can handle multiple intersecting split strings, as well +as returning all of the document IDs of a sample.""" + +import logging +from dataclasses import dataclass +from typing import Dict, List + +import numpy + +from megatron.core.datasets.blended_megatron_dataset_config import ( + convert_split_vector_to_split_matrix, + parse_and_normalize_split, +) +from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig +from megatron.core.datasets.indexed_dataset import IndexedDataset +from megatron.core.datasets.utils import Split +from megatron.core.utils import log_single_rank + +logger = logging.getLogger(__name__) + + +@dataclass +class MultiSplitGPTDatasetConfig(GPTDatasetConfig): + """Configuration object for Megatron Core blended and Retro datasets. + + Args: + return_document_ids (bool): Whether to return the document ids when querying the dataset. Turn this option on during preprocessing. + split_preprocessing (str): The Retro preprocessing split string. It follows the same pattern convention as 'split'. Not to be used with 'blend_per_split'. + """ + + return_document_ids: bool = None + + split_preprocessing: str = None + + def __post_init__(self) -> None: + """Validate config attributes.""" + super().__post_init__() + assert self.split is not None, "the Retro data pipeline does not support 'blend_per_split'" + assert self.return_document_ids is not None, "this attribute must be user defined" + assert self.split_preprocessing is not None, "this attribute must be user defined" + split_vector = parse_and_normalize_split(self.split) + split_preprocessing_vector = parse_and_normalize_split(self.split_preprocessing) + if not numpy.allclose(split_vector, split_preprocessing_vector): + self.split_matrix = convert_split_vector_to_split_matrix( + split_vector, split_preprocessing_vector + ) + log_single_rank( + logger, + logging.WARNING, + f"split =/= split_preprocessing. Let split_matrix = {self.split_matrix}", + ) + + +class MultiSplitGPTDataset(GPTDataset): + """Retro's customized GPT dataset. + + Args: + indexed_dataset (IndexedDataset): The IndexedDataset around which to build the MegatronDataset. + dataset_path (str): The real path on disk to the dataset, for bookkeeping. + indexed_indices (numpy.ndarray): The set of the documents indices to expose. + num_samples (int): The number of samples to draw from the indexed dataset. + index_split (Split): The indexed_indices Split. + config (MultiSplitGPTDatasetConfig): The Retro-specific container for all config sourced parameters. 
+ """ + + def __init__( + self, + indexed_dataset: IndexedDataset, + dataset_path: str, + indexed_indices: numpy.ndarray, + num_samples: int, + index_split: Split, + config: MultiSplitGPTDatasetConfig, + ) -> None: + super().__init__( + indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config + ) + + def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]: + """Get dataset sample. + + Args: + idx (int): The index into the dataset. + + Returns: + Dict[str, numpy.ndarray]: The text ids and (optionally) the document ids wrapped in a dictionary. + """ + text, document_ids = self._query_document_sample_shuffle_indices(idx) + if self.config.return_document_ids: + return {"text": text, "document_ids": document_ids} + else: + return {"text": text} + + @staticmethod + def _key_config_attributes() -> List[str]: + """Add custom attributes for building unique dataset hash. + + The preprocessing split used for preprocessing will constrain the samples available for pretraining. + + Returns: + List[str]: The key config attributes. + """ + return super(MultiSplitGPTDataset, MultiSplitGPTDataset)._key_config_attributes() + [ + "split_preprocessing" + ] diff --git a/megatron/core/datasets/retro/query/query.py b/megatron/core/datasets/retro/query/query.py new file mode 100644 index 0000000000..9da3381712 --- /dev/null +++ b/megatron/core/datasets/retro/query/query.py @@ -0,0 +1,393 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Entry point for querying an index using a GPTChunkDataset. + +Querying involves: + + - Iterate all chunks in the GPTChunkDataset. + - Query index for neighbor chunk IDs (i.e., chunks from the chunk database). + - Save neighbor chunk IDs to disk, for use in building a RetroDataset sample + during pretraining. +""" + +import os +import time +import typing + +import numpy as np +import psutil +import torch +from tqdm import tqdm + +from megatron.core.datasets.retro.config import RetroPreprocessingConfig +from megatron.core.datasets.retro.db.dataset import DBDataset +from megatron.core.datasets.retro.db.utils import ( + get_merged_train_dataset as get_db_merged_train_dataset, +) +from megatron.core.datasets.retro.external_libs import faiss, h5py +from megatron.core.datasets.retro.index.factory import IndexFactory +from megatron.core.datasets.retro.index.index import Index +from megatron.core.datasets.retro.index.utils import get_index_dir +from megatron.core.datasets.retro.query.gpt_chunk_dataset import GPTChunkDataset +from megatron.core.datasets.retro.utils import ( + GPTToTextDataset, + get_blocks_by_rank, + log_retro_rank_0, + retro_makedir, +) + +from .gpt_chunk_dataset import build_gpt_chunk_datasets_from_gpt_datasets + + +def get_index(config: RetroPreprocessingConfig, ondisk: bool = False) -> faiss.Index: + """Read index from disk. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + ondisk (bool): If `ondisk = True`, memory map the index. (For debugging purposes only; very non-performant.) + + Returns: + A Faiss index, loaded from storage. + """ + + # Load index. + index_wrapper = IndexFactory.get_index(config.retro_index_type) + index_dir = get_index_dir(config) + added_index_path = index_wrapper.get_added_index_path(config) + if ondisk: + index = faiss.read_index(added_index_path, faiss.IO_FLAG_MMAP) + else: + index = faiss.read_index(added_index_path) + + # Search parameters. 
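    # (Editorial note, not part of the original patch): `efSearch` controls the breadth of
    # the HNSW graph search and `nprobe` the number of IVF lists probed per query; raising
    # either trades query speed for higher recall.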
+ faiss.ParameterSpace().set_index_parameter(index, "efSearch", config.retro_query_ef_search) + faiss.ParameterSpace().set_index_parameter(index, "nprobe", config.retro_query_nprobe) + + return index + + +def embed_block( + config: RetroPreprocessingConfig, gpt_dataset: GPTChunkDataset, block: dict +) -> np.ndarray: + """Embed block of chunks. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + gpt_dataset (GPTChunkDataset): Chunk dataset to be embedded. + block (dict): Range information containing start/end indices of subset of chunk dataset. + + Returns: + Embeddings array, with shape (len(block["range"]), dimension(embedder)). + """ + text_block_dataset = torch.utils.data.Subset( + GPTToTextDataset(gpt_dataset, config.retro_tokenizers.gpt), range(*block["range"]) + ) + return config.retro_bert_embedders.mem.embed_text_dataset(text_block_dataset) + + +def query_embeddings( + config: RetroPreprocessingConfig, + db_dataset: DBDataset, + index: Index, + embeddings: np.ndarray, + chunk_id_range: range, + sample_map: dict, + n_chunks_per_sample: int, + verbose: bool = True, +) -> typing.Tuple[np.ndarray, np.ndarray]: + """Query neighbors of a block of embeddings. + + Querying includes: + - Query index for neighbor chunk IDs. + - Filter chunk IDs that have the same document ID as the queried embedding. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + db_dataset (DBDataset): Dataset containing chunk database entries. + index (Index): Vector index populated with chunk database indices. + embeddings (np.ndarray): Embeddings from GPT chunk dataset. + chunk_id_range (range): Chunk ID range from GPT chunk dataset. + sample_map (dict): Mapping of sample_idx to dataset_idx and document_ids. Used for document filtering. + n_chunks_per_sample (int): Number of chunks per sample (e.g., sequence_length / chunk_length). + verbose (bool): Log querying progress. + + Returns: + A tuple of original (unfiltered) neighbor IDs, and filtered (by document ID) neighbor IDs. + """ + + # Query neighbor ids. + if verbose: + log_retro_rank_0("search.") + t = time.time() + assert index.ntotal > 0, "check we don't accidentally have an empty index." + _, query_neighbor_ids = index.search(embeddings, config.retro_query_num_neighbors_query) + if verbose: + log_retro_rank_0(" time : %.3f sec." % (time.time() - t)) + + # Filter banned neighbor ids. + if verbose: + log_retro_rank_0("filter banned neighbor ids.") + filtered_neighbor_ids = np.full( + shape=(len(query_neighbor_ids), config.retro_query_num_neighbors_save), + fill_value=-1, + dtype="int64", + ) + min_chunk_id, max_chunk_id = chunk_id_range + for chunk_id in range(min_chunk_id, max_chunk_id): + + sample_id = chunk_id // n_chunks_per_sample + sample = sample_map[sample_id] + sample_dataset_idx = sample["dataset_idx"].item() + sample_doc_ids = sample["doc_ids"].tolist() + sample_doc_tuples = [(sample_dataset_idx, d) for d in sample_doc_ids] + + # Get valid neighbors (!= -1). + query_row = [i for i in query_neighbor_ids[chunk_id - min_chunk_id] if i >= 0] + + # Filter row. 
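        # (Editorial note, not part of the original patch): keep only neighbors whose
        # (dataset_idx, doc_id) tuple does not appear among the query sample's own
        # documents, then truncate to `retro_query_num_neighbors_save` entries and
        # right-pad the row with -1 to a fixed width.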
+ filtered_row = [ + i + for i in query_row + if tuple(db_dataset.doc_tuples[i].tolist()) not in sample_doc_tuples + ] + filtered_row = filtered_row[: config.retro_query_num_neighbors_save] + filtered_row += [-1] * (config.retro_query_num_neighbors_save - len(filtered_row)) + filtered_neighbor_ids[chunk_id - min_chunk_id] = filtered_row + + return query_neighbor_ids, filtered_neighbor_ids + + +def query_embedding_block( + config: RetroPreprocessingConfig, + db_dataset: DBDataset, + index: Index, + embeddings: np.ndarray, + chunk_id_range: range, + sample_map: dict, + n_chunks_per_sample: int, +) -> typing.Tuple[np.ndarray, np.ndarray]: + """Query a block of embeddings. + + The block is broken into smaller sub-blocks, for easier tracking of progress. + Both the raw neighbor IDs and the filtered neighbor IDs (i.e., chunks with the + same document ID are removed) are collected. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + db_dataset (DBDataset): Dataset containing chunk database entries. + index (Index): Vector index populated with chunk database indices. + embeddings (np.ndarray): Embeddings from GPT chunk dataset. + chunk_id_range (range): Chunk ID range from GPT chunk dataset. + sample_map (dict): Mapping of sample_idx to dataset_idx and document_ids. Used for document filtering. + n_chunks_per_sample (int): Number of chunks per sample (e.g., sequence_length / chunk_length). + + Returns: + A tuple of original (unfiltered) neighbor IDs, and filtered (by document ID) neighbor IDs. + """ + + query_neighbor_ids = [] + filtered_neighbor_ids = [] + + # Query in sub-blocks. + partial_block_size = 1000 + for partial_start_idx in tqdm( + range(0, len(embeddings), partial_block_size), + " search", + miniters=(len(embeddings) // partial_block_size) // 10, + disable=torch.distributed.get_rank() != 0, + ): + partial_end_idx = min(len(embeddings), partial_start_idx + partial_block_size) + partial_embeddings = embeddings[partial_start_idx:partial_end_idx] + partial_chunk_id_range = ( + chunk_id_range[0] + partial_start_idx, + chunk_id_range[0] + partial_end_idx, + ) + partial_query_neighbor_ids, partial_filtered_neighbor_ids = query_embeddings( + config, + db_dataset, + index, + partial_embeddings, + partial_chunk_id_range, + sample_map, + n_chunks_per_sample, + verbose=False, + ) + query_neighbor_ids.append(partial_query_neighbor_ids) + filtered_neighbor_ids.append(partial_filtered_neighbor_ids) + + # Concatenate. + query_neighbor_ids = np.concatenate(query_neighbor_ids, axis=0) + filtered_neighbor_ids = np.concatenate(filtered_neighbor_ids, axis=0) + + return query_neighbor_ids, filtered_neighbor_ids + + +def query_block_neighbors( + config: RetroPreprocessingConfig, + db_dataset: DBDataset, + query_dataset: GPTChunkDataset, + index: Index, + block: dict, +) -> None: + """Query neighbors of a dataset block (i.e., range). + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + db_dataset (DBDataset): Dataset containing chunk database entries. + query_dataset (GPTChunkDataset): GPT chunk dataset to be queried. + index (Index): Vector index populated with chunk database indices. + block (dict): Range information containing start/end indices for querying GPT chunk dataset. + """ + + n_chunks_per_sample = query_dataset.n_chunks_per_sample + + # Sample map. 
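    # (Editorial note, not part of the original patch): map each sample index covered by
    # this block to its dataset index and document IDs, so neighbors drawn from the same
    # documents can be filtered out during the query below.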
+ sample_ids = sorted( + list(set(chunk_id // n_chunks_per_sample for chunk_id in range(*block["range"]))) + ) + sample_map = {} + for i in sample_ids: + sample = query_dataset.sample_dataset[i] + sample_map[i] = {"dataset_idx": sample["dataset_id"], "doc_ids": sample["document_ids"]} + + # Embed block. + embeddings = embed_block(config, query_dataset, block) + + # Query embeddings. + _, filtered_neighbor_ids = query_embedding_block( + config, db_dataset, index, embeddings, block["range"], sample_map, n_chunks_per_sample + ) + + if config.retro_task_validate is None: + # Save neighbors. + log_retro_rank_0("save neighbors.") + retro_makedir(config, os.path.dirname(block["path"])) + f = h5py.File(block["path"], "w") + f.create_dataset("neighbors", data=filtered_neighbor_ids) + f.close() + + else: + # Validate neighbors. + with h5py.File(block["path"]) as f: + existing_neighbor_ids = np.copy(f["neighbors"]) + assert np.array_equal(existing_neighbor_ids, filtered_neighbor_ids) + + +def query_dataset_neighbors( + config: RetroPreprocessingConfig, + db_dataset: DBDataset, + query_dataset: GPTChunkDataset, + num_active_chunks: int, + prefix: str, + neighbor_dir: str, + index: Index, +) -> None: + """Query neighbors of each chunk within a dataset. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + db_dataset (DBDataset): Dataset containing chunk database entries. + query_dataset (GPTChunkDataset): GPT chunk dataset to be queried. + num_active_chunks (int): The 'active' chunks are the subset of the GPT chunk dataset that aren't being queried. This argument is used when validating the correctness of a subset of the GPT chunk dataset. + prefix (str): Extra string for logging progress. + neighbor_dir (str): File path to directory for saving neighbor IDs. + index (Index): Vector index populated with chunk database indices. + """ + + def validate(f: h5py.File) -> None: + """Validation method for validating saved neighbor IDs. + + Args: + f (h5py.File): File containing save neighbor IDs. + """ + assert ( + f["neighbors"].shape[1] == config.retro_query_num_neighbors_save + ), "neighbors.shape == %s; num_neighbors_target == %d." % ( + str(f["neighbors"].shape), + config.retro_num_neighbors_target, + ) + + if config.retro_task_validate is None: + retro_makedir(config, neighbor_dir) + blocks = get_blocks_by_rank( + neighbor_dir, num_active_chunks, config.retro_block_size, validate=validate + ) + active_blocks = blocks.missing + else: + blocks = get_blocks_by_rank( + neighbor_dir, + num_active_chunks, + config.retro_block_size, + validate=validate, + sample=config.retro_task_validate, + ) + assert blocks.n_missing_world == 0 + active_blocks = blocks.existing + + # Query each block. + for block_index, block in enumerate(active_blocks): + + if block is not None: + + # Progress. + log_retro_rank_0( + "%squery '%s' block %d / %d ... %s ... mem %.3f gb, %.1f%%." + % ( + "" if config.retro_task_validate is None else "[validate] ", + prefix, + block_index, + len(active_blocks), + os.path.basename(block["path"]), + psutil.virtual_memory()[3] / 1024**3, + psutil.virtual_memory()[2], + ) + ) + + # Query block neighbors. + query_block_neighbors(config, db_dataset, query_dataset, index, block) + + # Synchronize progress across all ranks. (for easier observation) + log_retro_rank_0(" > waiting for other ranks to finish block.") + torch.distributed.barrier() + + +def query_neighbors(config: RetroPreprocessingConfig) -> None: + """Query pretraining datasets (train & valid). 
+ + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + """ + + # Num threads. + faiss.omp_set_num_threads(64) + + # Load chunk db dataset. + log_retro_rank_0("load chunk db dataset.") + db_dataset = get_db_merged_train_dataset( + project_dir=config.retro_project_dir, + chunk_length=config.retro_gpt_chunk_length, + eod_token_id=config.retro_tokenizers.gpt.eod, + ) + db_dataset.load_doc_tuples() + + # Load index. + log_retro_rank_0(" > get index.") + index = get_index(config) + + # Query each (i.e., train, valid, test) dataset. + log_retro_rank_0(" > query.") + for prefix, info in vars(config.retro_gpt_chunk_datasets).items(): + if info is None: + continue + log_retro_rank_0( + " > query '%s' dataset ... %d samples." % (prefix, info["num_active_chunks"]) + ) + query_dataset_neighbors( + config, + db_dataset, + info["dataset"], + info["num_active_chunks"], + prefix, + info["neighbor_dir"], + index, + ) diff --git a/megatron/core/datasets/retro/query/retro_dataset.py b/megatron/core/datasets/retro/query/retro_dataset.py new file mode 100644 index 0000000000..6c3b9ae60c --- /dev/null +++ b/megatron/core/datasets/retro/query/retro_dataset.py @@ -0,0 +1,238 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +A RetroDataset wraps both: + + - A GPTDataset (which is nested as GPTChunkDataset -> MultiSplitGPTDataset -> + GPTDataset). + - Neighbor IDs of chunks in the chunk database, that were saved during + preprocessing. + +Both the GPT sample data and the neighbor IDs are returned within a sample from +this dataset. +""" + +import os +from typing import Any, Dict, Optional, Tuple + +import numpy as np +import torch + +from megatron.core.datasets.retro.db.dataset import DBDataset +from megatron.core.datasets.retro.db.utils import get_merged_train_dataset as get_db_dataset +from megatron.core.datasets.retro.external_libs import h5py +from megatron.core.datasets.retro.utils import BlockPathMap, log_retro_rank_0 +from megatron.core.models.retro import RetroConfig + +from .gpt_chunk_dataset import GPTChunkDataset, build_gpt_chunk_datasets_from_gpt_datasets +from .utils import get_query_dir + + +class RetroDataset(torch.utils.data.Dataset): + """Dataset of retro samples. + + Each sample contains the original GPT sample, along with the token IDs + of each neighbor of each chunk within the sequence. Neighbor array has + shape (num_chunks_per_sample, num_neighbors, num_retrieved_tokens). + + ** Note: chunk dataset wraps original GPT dataset (see gpt_chunk_dataset.py). + + Args: + num_queried_samples (int): Total number of queried samples. + num_neighbors (int): Total number of saved neighbors. + num_retrieved_chunks (int): Number of retrieved chunks (e.g., 2 for neighbor + continuation). + block_size (int): Number of neighbor entries per file. + db_dataset (DBDataset): Chunk database used for retrieval. + chunk_dataset (GPTChunkDataset): GPT chunk dataset, which is a wrapper around a standard GPT dataset that breaks each sample into chunks. + neighbor_path_map (BlockPathMap): Mapping of neighbor ID to file path. 
+ """ + + def __init__( + self, + num_queried_samples: int, + num_neighbors: int, + num_retrieved_chunks: int, + block_size: int, + db_dataset: DBDataset, + chunk_dataset: GPTChunkDataset, + neighbor_path_map: BlockPathMap, + ): + super().__init__() + + self.num_queried_samples = num_queried_samples + self.num_neighbors = num_neighbors + self.num_retrieved_chunks = num_retrieved_chunks + self.block_size = block_size + self.db_dataset = db_dataset + self.chunk_dataset = chunk_dataset + self.neighbor_path_map = neighbor_path_map + + def __len__(self) -> int: + """Dataset length. + + Returns: + Number of samples in dataset. + """ + return len(self.chunk_dataset.sample_dataset) + + def __getitem__(self, sample_idx: int) -> dict: + """Get dataset sample. + + Args: + sample_idx (int): Index of sample in dataset. + + Returns: + A dict consisting of GPT sample (attribute 'text') and corresponding neighbor chunk IDs ('neighbor_chunks', for indexing chunk database) and neighbor token IDs (corresponding chunk database GPT tokens). + """ + n_chunks_per_sample = self.chunk_dataset.n_chunks_per_sample + + # Wrap sample idx around number of queried samples. + sample_idx = sample_idx % self.num_queried_samples + + # Get standard sample. + sample = self.chunk_dataset.sample_dataset[sample_idx] + + # Sample idx to chunk idxs. + chunk_idxs = list( + range(sample_idx * n_chunks_per_sample, (sample_idx + 1) * n_chunks_per_sample) + ) + + # Collect retrieved tokens. + all_retrieved_chunk_ids = [] + all_retrieved_token_ids = [] + for chunk_idx in chunk_idxs: + + # Neighbor chunk ids. + neighbor_path = self.neighbor_path_map[chunk_idx] + with h5py.File(neighbor_path, "r") as f: + neighbor_chunk_ids = f["neighbors"][ + chunk_idx % self.block_size, : self.num_neighbors + ].tolist() + + # Retrieved (neighbor + continuation) token ids. + retrieved_chunk_ids = [] + retrieved_token_ids = [] + for neighbor_chunk_id in neighbor_chunk_ids: + current_chunk_ids = [ + i % len(self.db_dataset) + for i in range(neighbor_chunk_id, neighbor_chunk_id + self.num_retrieved_chunks) + ] + current_token_ids = [self.db_dataset[ci]["text"] for ci in current_chunk_ids] + retrieved_chunk_ids.append(current_chunk_ids) + retrieved_token_ids.append(current_token_ids) + + # Collect retrieved tokens. + all_retrieved_chunk_ids.append(retrieved_chunk_ids) + all_retrieved_token_ids.append(retrieved_token_ids) + + # Reshape retrieved tokens. + all_retrieved_chunk_ids = np.array(all_retrieved_chunk_ids).reshape( + (n_chunks_per_sample, self.num_neighbors, -1) + ) + all_retrieved_token_ids = np.array(all_retrieved_token_ids).reshape( + (n_chunks_per_sample, self.num_neighbors, -1) + ) + + # Sample. + sample: Dict[str, np.ndarray] = { + **sample, + "neighbor_chunks": all_retrieved_chunk_ids, + "neighbor_tokens": all_retrieved_token_ids, + } + + return sample + + +def get_retro_datasets( + config: RetroConfig, gpt_datasets: dict, sample_length: int, eod_token_id: int +) -> Tuple[Optional[RetroDataset], Optional[RetroDataset], Optional[RetroDataset]]: + """Get train, valid, test retro datasets. + + Args: + config (RetroConfig): Retro preprocessing config. + gpt_datasets (dict): Mapping of data split key ('train', 'valid', or 'test') to the original sequence-length GPT dataset (i.e., not the chunk dataset). + sample_length (int): Alias to `sequence_length`. + eod_token_id (int): GPT EOD token ID. + + Returns: + A tuple of 'train', 'valid', and 'test' `RetroDataset`s. + """ + + # DB dataset. 
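    # (Editorial note, not part of the original patch): the merged training chunk database
    # built during preprocessing; the neighbor IDs stored on disk index into this dataset
    # to recover the retrieved neighbor (and continuation) tokens.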
+ db_dataset = get_db_dataset( + project_dir=config.retro_project_dir, + chunk_length=config.retro_chunk_length, + eod_token_id=eod_token_id, + ) + + # GPT chunk datasets. + chunk_ds_info_map = build_gpt_chunk_datasets_from_gpt_datasets( + project_dir=config.retro_project_dir, + gpt_datasets=gpt_datasets, + sample_length=sample_length, + chunk_length=config.retro_chunk_length, + ) + + # Retro datasets. + retro_dataset_map: Dict[str, Optional[RetroDataset]] = {} + query_dir = get_query_dir(config.retro_project_dir) + for data_key, chunk_ds_info in chunk_ds_info_map.items(): + + # Skip unused datasets. + if chunk_ds_info is None: + retro_dataset_map[data_key] = None + continue + + # For consistency with preprocessing, the neighbor_dir is overwritten + # (from its setting in `build_gpt_chunk_datasets_from_gpt_datasets()` + # above). This is one piece -- along with setting data_path and + # train_samples from config.json -- of ensuring consistency between + # preprocessing and pretraining. + chunk_dataset = chunk_ds_info["dataset"] + chunk_ds_info["neighbor_dir"] = os.path.join( + query_dir, config.retro_neighbor_dirs[data_key] + ) + neighbor_dir = chunk_ds_info["neighbor_dir"] + neighbor_path_map = BlockPathMap.from_dir( + dir=neighbor_dir, block_size=config.retro_block_size + ) + + # Verify num chunks. + n_active_chunks = chunk_ds_info["num_active_chunks"] + n_neighbor_chunks = neighbor_path_map.max_idx + + if not os.path.isdir(neighbor_dir): + if torch.distributed.get_rank() == 0: + raise Exception( + "neighbor directory '%s' not found; please " + "compare --train-samples, --seq-length, --seed, " + "--eval-iters, and --eval-interval, with " + "retro preprocessing args." % neighbor_dir + ) + torch.distributed.barrier() + exit() + + if config.retro_verify_neighbor_count and n_active_chunks != n_neighbor_chunks: + if torch.distributed.get_rank() == 0: + log_retro_rank_0("neighbor_dir : %s" % neighbor_dir) + log_retro_rank_0("neighbor_path_map : %s" % neighbor_path_map) + raise Exception( + "num sampled chunks (%d) != num neighbor chunks " + "(%d); did you complete querying the entire " + "pretraining dataset?" % (n_active_chunks, n_neighbor_chunks) + ) + torch.distributed.barrier() + exit() + + # Retro dataset. + retro_dataset_map[data_key] = RetroDataset( + num_queried_samples=gpt_datasets[data_key][1], + num_neighbors=config.retro_num_neighbors, + num_retrieved_chunks=config.retro_num_retrieved_chunks, + block_size=config.retro_block_size, + db_dataset=db_dataset, + chunk_dataset=chunk_dataset, + neighbor_path_map=neighbor_path_map, + ) + + return (retro_dataset_map["train"], retro_dataset_map["valid"], retro_dataset_map["test"]) diff --git a/megatron/core/datasets/retro/query/utils.py b/megatron/core/datasets/retro/query/utils.py new file mode 100644 index 0000000000..b4e0c67009 --- /dev/null +++ b/megatron/core/datasets/retro/query/utils.py @@ -0,0 +1,35 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Utilities for querying the pretraining dataset.""" + +import os + +from megatron.core.datasets.megatron_dataset import MegatronDataset + + +def get_query_dir(project_dir: str) -> str: + """Get root directory of all saved query data. + + Args: + project_dir (str): Retro project dir. + + Returns: + Path to query sub-directory in Retro project. + """ + return os.path.join(project_dir, "query") + + +def get_neighbor_dir(project_dir: str, key: str, dataset: MegatronDataset) -> str: + """Get directory containing neighbor IDs for a dataset (i.e., train, valid, or test). 
+ + Args: + project_dir (str): Retro project dir. + key (str): Dataset split key; 'train', 'valid', or 'test'. + dataset (MegatronDataset): Dataset containing unique hash for finding corresponding neighbors. + + Returns: + Path to directory containing this dataset's neighbors within Retro project. + """ + return os.path.join( + get_query_dir(project_dir), os.path.basename(f"{key}_{dataset.unique_description_hash}") + ) diff --git a/megatron/core/datasets/retro/utils.py b/megatron/core/datasets/retro/utils.py new file mode 100644 index 0000000000..31c0be14c8 --- /dev/null +++ b/megatron/core/datasets/retro/utils.py @@ -0,0 +1,349 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +"""Utilities for Retro preprocessing.""" + +import glob +import logging +import os +from collections import defaultdict +from types import SimpleNamespace +from typing import Any, Callable, Dict, List, Optional + +import numpy as np +import torch +from tqdm import tqdm + +from megatron.core import parallel_state +from megatron.core.datasets.retro.config import RetroPreprocessingConfig +from megatron.core.datasets.retro.query.multi_split_gpt_dataset import ( + MultiSplitGPTDataset, + MultiSplitGPTDatasetConfig, +) +from megatron.core.utils import log_single_rank + +from .external_libs import h5py + +logger = logging.getLogger(__name__) + + +def log_retro_rank_0(message: str) -> None: + """Log on rank 0. + + Args: + message (str): Message to log. + """ + log_single_rank(logger, logging.INFO, "[RETRO] " + message) + + +def retro_makedir(config: RetroPreprocessingConfig, path: str) -> None: + """Make a directory, conditional on not being in validation mode. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + path (str): Path to directory. + """ + if config.retro_task_validate is None: + os.makedirs(path, exist_ok=True) + + +def extract_data_config(config: RetroPreprocessingConfig) -> MultiSplitGPTDatasetConfig: + """Extract data config from dataset. + + Args: + config (RetroPreprocessingConfig): Retro preprocessing config. + + Returns: + The config object used to build the dataset. + """ + return config.retro_gpt_chunk_datasets.train["dataset"].sample_dataset.config + + +def get_num_chunks_per_sample(sample_length: int, chunk_length: int) -> int: + """Compute seq_length // chunk_length. + + Args: + sample_length (int): Alias of `sequence_length`. + chunk_length (int): Retro chunk length (e.g., 64). + + Returns: + Number of chunks per sample (i.e., `sequence_length` / `chunk_length`). + """ + assert sample_length % chunk_length == 0 + return sample_length // chunk_length + + +class GPTToTextDataset(torch.utils.data.Dataset): + """Dataset to convert GPT tokens to text. + + Args: + gpt_dataset (MultiSplitGPTDataset): GPT dataset, which outputs GPT token samples. + gpt_tokenizer (Any): GPT tokenizer. + """ + + def __init__(self, gpt_dataset: MultiSplitGPTDataset, gpt_tokenizer: Any): + + super().__init__() + + self.gpt_dataset = gpt_dataset + self.gpt_tokenizer = gpt_tokenizer + + def __len__(self) -> int: + """Dataset length. + + Returns: + Number of samples in the dataset. + """ + return len(self.gpt_dataset) + + def __getitem__(self, idx: int) -> dict: + """Get dataset sample. + + Args: + idx (int): Index of sample. + + Returns: + A dict containing attribute 'text' of type string. 
+ """ + gpt_token_ids = self.gpt_dataset[idx]["text"].tolist() + text = self.gpt_tokenizer.detokenize(gpt_token_ids) + return {"text": text} + + +def get_blocks( + dirname: str, n_samples: int, block_size: int, validate: Callable = None +) -> SimpleNamespace: + """Divide range [0, num_samples) to sequence of block ranges. + + This is a core method within the concept of block processing. The idea + is to divide a range (size n_samples) into a sequence of blocks. Each + block corresponds to a file within 'dirname' with name + '{start_idx}-{end_idx}.hdf5'. This method checks for the existence of + these files, and returns two lists, one for existing blocks and one for + missing blocks. + + Args: + dirname (str): Path to directory containing block files. + n_samples (int): Ideal number of samples. The total number of saved block data is <=n_samples. + block_size (int): Max number of samples per block file (e.g., 100000). + validate (Callable): Method for validating each block file during load. + + Returns: + A namespace consisting of 2 lists: existing blocks, and missing blocks. The total number of samples between the existing and missing blocks should equal n_samples above. + """ + + assert os.path.isdir(dirname), "missing directory '%s.'" % dirname + + # Block ranges. + block_start_idxs = list(range(0, n_samples, block_size)) + block_end_idxs = [min(n_samples, i + block_size) for i in block_start_idxs] + block_ranges = list(zip(block_start_idxs, block_end_idxs)) + + # All block files (existing + missing). + n_digits = int(np.ceil(np.log(n_samples) / np.log(10)) + 1) + all_blocks = [ + { + "range": r, + "path": os.path.join( + dirname, "%s-%s.hdf5" % tuple([str(i).zfill(n_digits) for i in r]) + ), + } + for r in block_ranges + ] + all_block_path_set = set(block["path"] for block in all_blocks) + + # Validate function. + validate = (lambda f: None) if validate is None else validate + + # Delete corrupt files. + if torch.distributed.get_rank() == 0: + existing_block_paths = [ + block["path"] for block in all_blocks if os.path.exists(block["path"]) + ] + for index, path in enumerate(tqdm(existing_block_paths, "validating block.")): + + assert path in all_block_path_set, "unexpected filename, '%s'." % path + + try: + f = h5py.File(path, "r") + except Exception: + os.remove(path) + continue + + try: + validate(f) + except Exception: + os.remove(path) + finally: + f.close() + + # Wait for files to be deleted. + torch.distributed.barrier() + + # Collect blocks. + blocks = SimpleNamespace( + existing=[b for b in all_blocks if os.path.exists(b["path"])], + missing=[b for b in all_blocks if not os.path.exists(b["path"])], + ) + + return blocks + + +def get_blocks_by_rank( + dirname: str, + n_samples: int, + block_size: int, + validate: Callable = None, + sample: Optional[float] = None, +) -> SimpleNamespace: + """Divide existing and missing blocks evenly across all ranks. + + See 'get_blocks()' above for description. The returned lists of existing and + missing blocks are split evenly across ranks via interleaving. This way, + each rank has a roughly equal number of blocks to process for a + downstream operation. + + Args: + dirname (str): Path to directory containing block files. + n_samples (int): Ideal number of samples. The total number of saved block data is <=n_samples. + block_size (int): Max number of samples per block file (e.g., 100000). + validate (Callable): Method for validating each block file during load. + sample (Optional[float]): If provided, sample a random subset of the blocks. 
Used for validating preprocessing correctness. + + Returns: + A namespace consisting of 2 lists: existing blocks, and missing blocks. Each of these two lists is potentially a sub-sample of the total set of existing and missing blocks, depending on whether sampling is used. Additionally, the attributes n_existing_world and n_missing_world are the total number of existing and missing blocks, independent of samples. Therefore, (n_existing_world + n_missing_world) * block_size == n_samples. + """ + + # Get world blocks. + blocks = get_blocks(dirname, n_samples, block_size, validate) + + # This rank's existing and missing files. + data_parallel_rank = parallel_state.get_data_parallel_rank() + data_parallel_world_size = parallel_state.get_data_parallel_world_size() + rank_existing_blocks = blocks.existing[ + data_parallel_rank : len(blocks.existing) : data_parallel_world_size + ] + rank_missing_blocks = blocks.missing[ + data_parallel_rank : len(blocks.missing) : data_parallel_world_size + ] + + # Extend rank's existing and missing blocks (with None) such that all ranks + # have equal length lists. This allows for easier tracking of global progress. + def get_world_max(n: int) -> int: + """Get max value across ranks. + + Args: + n (int): Value on this rank. + + Returns: + Max value across all ranks. + """ + n_tensor = torch.cuda.LongTensor([n]) + torch.distributed.all_reduce(n_tensor, op=torch.distributed.ReduceOp.MAX) + return n_tensor.item() + + max_n_existing = get_world_max(len(rank_existing_blocks)) + max_n_missing = get_world_max(len(rank_missing_blocks)) + + rank_existing_blocks += [None] * (max_n_existing - len(rank_existing_blocks)) + rank_missing_blocks += [None] * (max_n_missing - len(rank_missing_blocks)) + + # Collect blocks. + blocks = SimpleNamespace( + n_existing_world=len(blocks.existing), + n_missing_world=len(blocks.missing), + existing=rank_existing_blocks, + missing=rank_missing_blocks, + ) + + if sample is not None: + # Sample existing and missing blocks evenly across all ranks. The + # returned lists of blocks are randomly sampled (without replacement) + # to yield `sample * len(blocks)` number of blocks. + + # Randomly sample blocks. + def sample_blocks(_blocks: List[Optional[Dict]]) -> List[Optional[Dict]]: + """Sample a random subset of all blocks. + + Args: + _blocks (List[Optional[Dict]]): List of all blocks. + + Returns: + A random subset of the blocks. + """ + n_blocks_sample = int(np.ceil(sample * len(_blocks))) + sampled_blocks: List[Optional[Dict]] = [b for b in _blocks if b is not None] + + np.random.seed(None) + np.random.shuffle(sampled_blocks) + + sampled_blocks = sampled_blocks[:n_blocks_sample] + sampled_blocks += [None] * (n_blocks_sample - len(sampled_blocks)) + + return sampled_blocks + + blocks.existing = sample_blocks(blocks.existing) + blocks.missing = sample_blocks(blocks.missing) + + return blocks + + +class BlockPathMap: + """Map an index to its containing block path. + + The common use for this class is to have a directory of files containing + blocks of processed data, of uniform block size (e.g., 100k samples per + file). Each file must follow a naming convention of 'startIdx-endIdx.[ext]', + where 'endIdx' minus 'startIdx' must equal the block size, with the possible + exception of the final block. Given an input index, this class maps the + index to the containing block file. + + Args: + block_paths (List[str]): List of paths to saved block files. + block_size (int): Max number of samples per block file (e.g., 100000). 
+ """ + + @classmethod + def from_dir(cls, dir: str, block_size: int, ext: str = "hdf5") -> Any: + """Get list of block files, and create map. + + Args: + dir (str): Path to directory containing saved block files. + block_size (int): Max number of samples per block file (e.g., 100000). + ext (str): Block file extension (e.g., 'hdf5'). + + Returns: + A mapping of sample index to block file path. + """ + assert os.path.isdir(dir), f"directory not found, '{dir}'." + return cls(sorted(glob.glob(dir + f"/*.{ext}")), block_size) + + def __init__(self, block_paths: List[str], block_size: int): + self.max_idx = 0 + self.block_path_map = {} + for block_path in block_paths: + name = os.path.splitext(os.path.basename(block_path))[0] + start_idx, end_idx = [int(i) for i in name.split("-")] + self.block_path_map[start_idx] = block_path + self.max_idx = max(self.max_idx, end_idx) + self.block_size = block_size + + def __str__(self) -> str: + """Stringify the mapping. + + Returns: + A string representation of this block path map. + """ + return "%d paths" % len(self.block_path_map) + + def __getitem__(self, idx: int) -> str: + """Get block path from index. + + Args: + idx (int): Index of sample. + + Returns: + The path to the block file containing the sample index. + """ + block_start_idx = self.block_size * (idx // self.block_size) + block_path = self.block_path_map[block_start_idx] + return block_path diff --git a/megatron/core/datasets/t5_dataset.py b/megatron/core/datasets/t5_dataset.py new file mode 100644 index 0000000000..f356426ed2 --- /dev/null +++ b/megatron/core/datasets/t5_dataset.py @@ -0,0 +1,331 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +import os +from collections import deque +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Union + +import numpy +import torch +from packaging.version import Version as PkgVersion + +from megatron.core.datasets.indexed_dataset import IndexedDataset +from megatron.core.datasets.masked_dataset import ( + MaskedWordPieceDataset, + MaskedWordPieceDatasetConfig, +) +from megatron.core.datasets.utils import Split +from megatron.core.utils import get_te_version + + +@dataclass +class T5MaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig): + """Configuration object for Megatron Core T5 WordPiece datasets + + NB: As a temporary holdover from Megatron-LM. The T5 tokenizer has an attribute which defines + a number of special sentinel tokens used during sampling. The assert in __post_init__ serves to + preserve compatibility with Megatron-LM until the T5 tokenizer is in Megatron Core. 
+ """ + + sequence_length_encoder: Optional[int] = field(init=False, default=None) + """A sequence_length alias and the sequence length for the encoder""" + + sequence_length_decoder: int = None + """The sequence length for the decoder""" + + def __post_init__(self) -> None: + """Do asserts and set fields post init""" + super().__post_init__() + + self.sequence_length_encoder = self.sequence_length + + assert self.sequence_length_encoder is not None + assert self.sequence_length_decoder is not None + + assert len(self.tokenizer.additional_special_tokens_ids) > 0 + + +class T5MaskedWordPieceDataset(MaskedWordPieceDataset): + """The T5 dataset that assumes WordPiece tokenization + + Args: + indexed_dataset (IndexedDataset): The IndexedDataset around + which to build the MegatronDataset + + dataset_path (str): The real path on disk to the dataset, for bookkeeping + + indexed_indices (numpy.ndarray): The set of the documents indices to expose + + num_samples (Optional[int]): The number of samples to draw from the indexed + dataset. When None, build as many samples as correspond to one epoch. + + index_split (Split): The indexed_indices Split + + config (T5MaskedWordPieceDatasetConfig): The config + """ + + def __init__( + self, + indexed_dataset: IndexedDataset, + dataset_path: str, + indexed_indices: numpy.ndarray, + num_samples: Optional[int], + index_split: Split, + config: T5MaskedWordPieceDatasetConfig, + ) -> None: + super().__init__( + indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config + ) + + self.token_lookup = list(self.config.tokenizer.inv_vocab.keys()) + # Account for the single and single token ids + self.sample_index = self._build_sample_index(self.config.sequence_length - 2, 1) + + @staticmethod + def _key_config_attributes() -> List[str]: + """Inherited method implementation + + Returns: + List[str]: The key config attributes + """ + return super( + T5MaskedWordPieceDataset, T5MaskedWordPieceDataset + )._key_config_attributes() + ["sequence_length_decoder"] + + @staticmethod + def _build_b1ss_attention_mask( + source_block: torch.tensor, target_block: torch.tensor, make_history_mask: bool = False + ) -> torch.tensor: + """Build an attention-mask having shape (bs, 1, q_len, kv_len) + from source_block and target_block + + Args: + source_block (torch.tensor): A 2-D array of tokens (bs, q_len) + target_block (torch.tensor): A 2-D array of tokens (bs, kv_len) + make_history_mask (bool): Whether to turn mask into causal mask + + Returns: + torch.tensor: The 4-D attention mask (bs, 1, q_len, kv_len) + """ + batch_size = source_block.shape[0] + attention_mask = [] + for i in range(batch_size): + source_sample = source_block[i] + target_sample = target_block[i] + mask = (target_sample[None, :] >= 1) * (source_sample[:, None] >= 1) + if make_history_mask: + arange = numpy.arange(source_sample.shape[0]) + history_mask = arange[None,] <= arange[:, None] + history_mask = torch.tensor(history_mask).to(mask.device) + mask = mask * history_mask + mask = ~(mask) # flip True to False + attention_mask.append(mask) + attention_mask = torch.stack(attention_mask) + attention_mask = attention_mask.unsqueeze(1) + return attention_mask + + @staticmethod + def config_attention_mask( + encoder_tokens: torch.tensor, + decoder_tokens: torch.tensor, + encoder_mask: torch.tensor, + decoder_mask: torch.tensor, + use_local: bool = False, + test_te_version: str = None, + ) -> torch.tensor: + """Config attention-mask for encoder_mask, decoder_mask, encoder_decoder_mask + 
conditioned on transformer-implementation (e.g. TE vs local), TE versions, + and TE backends + + Args: + encoder_tokens (torch.tensor): A 2-D array of tokens (bs, kv_len) + decoder_tokens (torch.tensor): A 2-D array of tokens (bs, q_len) + encoder_mask (torch.tensor): A 2-D array of tokens (bs, kv_len) + decoder_mask (torch.tensor): A 2-D array of tokens (bs, q_len) + use_local (bool): Whether the current T5 model uses local (vs TE) + transformer implmentation + + Returns: + Configured encoder_mask, decoder_mask, encoder_decoder_mask + torch.tensor: configured encoder attention mask + torch.tensor: configured decoder attention mask + torch.tensor: configured encoder-decoder attention mask + """ + # If using local transformer implementation (not transformer_engine): + # re-organize all attention masks, because local and transformer_engine + # backbones use different masks shapes. E.g.: + # (local: b1ss - transformer_engine: b11s) + if use_local: + encoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask( + encoder_tokens, encoder_tokens + ) + decoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask( + decoder_tokens, decoder_tokens, make_history_mask=True + ) + encoder_decoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask( + decoder_tokens, encoder_tokens + ) + + else: + # If using transformer_engine transformer implementation: + # 1. For TE version >= 1.10, across all 3 backends, + # The padding mask is configued as + # [bs, 1, 1, seq_len] for self-attention and + # ([bs, 1, 1, q_len], [bs, 1, 1, kv_len]) for cross-attention + # 2. For TE version >=1.7 and <1.10, when using Non-fused backend, + # The padding mask is configued as + # [bs, 1, q_len, kv_len] for both self-attention and for cross-attention + # 3. For TE version <1.7, only support Non-fused backend + # The padding mask is configued as + # [bs, 1, q_len, kv_len] for both self-attention and for cross-attention + + # Process for Flash/Fused + encoder_mask = encoder_mask.unsqueeze(1).unsqueeze(1) + decoder_mask = decoder_mask.unsqueeze(1).unsqueeze(1) + encoder_decoder_mask = (decoder_mask, encoder_mask) + # set decoder_mask to None because decoder uses AttnMaskType.causal + decoder_mask = None + + # get TE version, using test TE version if not None + if test_te_version is not None: + te_version = PkgVersion(test_te_version) + else: + te_version = get_te_version() + + # Check for older TE version than 1.10, adjust attention mask accordingly + flash_attention_enabled = os.getenv('NVTE_FLASH_ATTN') == '1' + fused_attention_enabled = os.getenv('NVTE_FUSED_ATTN') == '1' + if (te_version < PkgVersion("1.10.0")) and (te_version >= PkgVersion("1.7.0")): + if not (flash_attention_enabled) and not (fused_attention_enabled): + encoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask( + encoder_tokens, encoder_tokens + ) + encoder_decoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask( + decoder_tokens, encoder_tokens + ) + else: + pass + elif te_version < PkgVersion("1.7.0"): + if not (flash_attention_enabled) and not (fused_attention_enabled): + encoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask( + encoder_tokens, encoder_tokens + ) + encoder_decoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask( + decoder_tokens, encoder_tokens + ) + else: + assert not flash_attention_enabled and not fused_attention_enabled, ( + "Flash and fused attention is not supported with transformer " + "engine version < 1.7. 
Set NVTE_FLASH_ATTN=0 and NVTE_FUSED_ATTN=0" + "or upgrade transformer engine >= 1.7" + ) + return encoder_mask, decoder_mask, encoder_decoder_mask + + def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]: + """Abstract method implementation + + Args: + idx (int): The index into the dataset + + Returns: + Dict[str, Union[int, numpy.ndarray]]: The + """ + idx_beg, idx_end, target_sequence_length = self.sample_index[idx] + sample = [self.dataset[i] for i in range(idx_beg, idx_end)] + + numpy_random_state = numpy.random.RandomState(seed=(self.config.random_seed + idx) % 2**32) + + assert target_sequence_length <= self.config.sequence_length + + # Flatten the sample into a list of tokens + tokens = [token for sentence in sample for token in sentence] + + # Truncate the list of tokens to a desired length + truncated = len(tokens) > target_sequence_length + tokens = tokens[:target_sequence_length] + + # Masking + (tokens, _, _, _, masked_spans) = self._create_masked_lm_predictions( + tokens, target_sequence_length, numpy_random_state + ) + + # Prepare the encoder input and decoder input and output + sentinels = deque(self.config.tokenizer.additional_special_tokens_ids) + encoder_input = [] + decoder_input = [self.config.tokenizer.bos] + decoder_output = [] + idx_beg = 0 + for indices, labels in masked_spans: + sentinel = sentinels.popleft() + + # set the end index + idx_end = indices[0] + + encoder_input.extend(tokens[idx_beg:idx_end]) + encoder_input.append(sentinel) + + decoder_input.append(sentinel) + decoder_input.extend(labels) + + decoder_output.append(sentinel) + decoder_output.extend(labels) + + # set the start index + idx_beg = indices[-1] + 1 + + encoder_input.extend(tokens[idx_beg:]) + decoder_output.append(self.config.tokenizer.eos) + + # Pad the sequences and convert to NumPy + length_toks_encoder = len(encoder_input) + length_toks_decoder = len(decoder_input) + length_pads_encoder = self.config.sequence_length_encoder - length_toks_encoder + length_pads_decoder = self.config.sequence_length_decoder - length_toks_decoder + assert length_pads_encoder >= 0 + assert length_pads_decoder >= 0 + + encoder_input = numpy.array(encoder_input, dtype=numpy.int64) + encoder_input = numpy.pad( + encoder_input, (0, length_pads_encoder), constant_values=self.config.tokenizer.pad + ) + + decoder_input = numpy.array(decoder_input, dtype=numpy.int64) + decoder_input = numpy.pad( + decoder_input, (0, length_pads_decoder), constant_values=self.config.tokenizer.pad + ) + + # Create attention and history masks + mask_encoder = numpy.array([1] * length_toks_encoder + [0] * length_pads_encoder) + mask_decoder = numpy.array([1] * length_toks_decoder + [0] * length_pads_decoder) + mask_encoder_decoder = None + + # Mask the labels + decoder_output = numpy.array(decoder_output, dtype=numpy.int64) + decoder_output = numpy.pad(decoder_output, (0, length_pads_decoder), constant_values=-1) + + # Get the loss mask + loss_mask = numpy.zeros(self.config.sequence_length_decoder, dtype=numpy.int64) + loss_mask[:length_toks_decoder] = 1 + + return { + "text_enc": encoder_input, + "text_dec": decoder_input, + "labels": decoder_output, + "loss_mask": loss_mask, + "truncated": int(truncated), + "enc_mask": mask_encoder, + "dec_mask": mask_decoder, + } + + def _get_token_mask(self, numpy_random_state: numpy.random.RandomState) -> int: + """Abstract method implementation + + 100% of the time, replace the token id with mask token id. 
+ + Args: + numpy_random_state (RandomState): The NumPy random state + + Returns: + int: The mask token id + """ + return self.config.tokenizer.mask diff --git a/megatron/core/datasets/utils.py b/megatron/core/datasets/utils.py index 8a3279b5f4..8d887d4a4a 100644 --- a/megatron/core/datasets/utils.py +++ b/megatron/core/datasets/utils.py @@ -2,11 +2,13 @@ import logging from enum import Enum -from typing import List +from typing import List, Optional, Tuple import numpy import torch +from ..utils import log_single_rank + logger = logging.getLogger(__name__) @@ -17,8 +19,7 @@ class Split(Enum): def compile_helpers(): - """Compile C++ helper functions at runtime. Make sure this is invoked on a single process. - """ + """Compile C++ helper functions at runtime. Make sure this is invoked on a single process.""" import os import subprocess @@ -30,21 +31,6 @@ def compile_helpers(): sys.exit(1) -def log_single_rank(logger: logging.Logger, *args, rank=0, **kwargs): - """If torch distributed is initialized, log only on rank - - Args: - logger (logging.Logger): The logger to write the logs - - rank (int, optional): The rank to write on. Defaults to 0. - """ - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == rank: - logger.log(*args, **kwargs) - else: - logger.log(*args, **kwargs) - - def normalize(weights: List[float]) -> List[float]: """Do non-exponentiated normalization @@ -58,3 +44,44 @@ def normalize(weights: List[float]) -> List[float]: w_sum = numpy.sum(w) w = (w / w_sum).tolist() return w + + +def get_blend_from_list( + blend: Optional[List[str]], +) -> Optional[Tuple[List[str], Optional[List[float]]]]: + """Get the megatron.core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig blend from the blend list + + Args: + blend (Optional[List[str]]): The blend list, which can be either (1) a list of prefixes, e.g. ["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], or (2) a flattened, zipped list of weights and prefixes, e.g. ["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"] + + Returns: + Optional[Tuple[List[str], Optional[List[float]]]]: The blend, consisting of a list of dataset prefixes and optionally a list of dataset weights, e.g. [["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], [30.0, 70.0]]. + """ + if blend is None: + return None + + if len(blend) % 2 == 1: + weight_per_dataset = None + raw_prefix_per_dataset = blend + else: + raw_weight_per_dataset, raw_prefix_per_dataset = zip( + *[(blend[i], blend[i + 1]) for i in range(0, len(blend), 2)] + ) + + weight_per_dataset = [] + for rwpd in raw_weight_per_dataset: + try: + weight = float(rwpd) + except ValueError: + weight = None + weight_per_dataset.append(weight) + + is_none = map(lambda _: _ is None, weight_per_dataset) + if any(is_none): + assert all(is_none) + weight_per_dataset = None + raw_prefix_per_dataset = blend + + prefix_per_dataset = [rppd.strip() for rppd in raw_prefix_per_dataset] + + return prefix_per_dataset, weight_per_dataset diff --git a/megatron/core/datasets/utils_s3.py b/megatron/core/datasets/utils_s3.py new file mode 100644 index 0000000000..61103b429d --- /dev/null +++ b/megatron/core/datasets/utils_s3.py @@ -0,0 +1,164 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+import os +from typing import Any, Dict, NamedTuple, Protocol, Tuple + +import torch + +try: + import boto3 + import botocore.exceptions as exceptions +except ModuleNotFoundError: + pass + +S3_PREFIX = "s3://" + + +class S3Config(NamedTuple): + """Config when the data (.bin) file and the index (.idx) file are in S3 + + TODO: These parameters are few and can be consolidated with parameters specific to bin reader + classes - @jkamalu + + Attributes: + + path_to_idx_cache (str): The local directory where we will store the index (.idx) file + + bin_chunk_nbytes (int): If the number of bytes is too small, then we send a request to S3 at each call of the `read` method in _S3BinReader, which is slow, because each request has a fixed cost independent of the size of the byte range requested. If the number of bytes is too large, then we only rarely have to send requests to S3, but it takes a lot of time to complete the request when we do, which can block training. We've found that 256 * 1024 * 1024 (i.e., 256 MiB) has worked well (though we have not put that much effort into tuning it), so we default to it. + """ + + path_to_idx_cache: str + + bin_chunk_nbytes: int = 256 * 1024 * 1024 + + +class S3Client(Protocol): + """The protocol which all s3 clients should abide by""" + + def download_file(self, Bucket: str, Key: str, Filename: str) -> None: ... + + def upload_file(self, Filename: str, Bucket: str, Key: str) -> None: ... + + def head_object(self, Bucket: str, Key: str) -> Dict[str, Any]: ... + + def get_object(self, Bucket: str, Key: str, Range: str) -> Dict[str, Any]: ... + + def close(self) -> None: ... + + +def is_s3_path(path: str) -> bool: + """Ascertain whether a path is in S3 + + Args: + path (str): The path + + Returns: + bool: True if the path is in S3, False otherwise + """ + return path.startswith(S3_PREFIX) + + +def parse_s3_path(path: str) -> Tuple[str, str]: + """Parses the given S3 path returning correspsonding bucket and key. 
+ + Args: + path (str): The S3 path + + Returns: + Tuple[str, str]: A (bucket, key) tuple + """ + assert is_s3_path(path) + parts = path.replace(S3_PREFIX, "").split("/") + bucket = parts[0] + if len(parts) > 1: + key = "/".join(parts[1:]) + assert S3_PREFIX + bucket + "/" + key == path + else: + key = "" + return bucket, key + + +def object_exists(client: S3Client, path: str) -> bool: + """Ascertain whether the object at the given S3 path exists in S3 + + Args: + client (S3Client): The S3 client + + path (str): The S3 path + + Raises: + botocore.exceptions.ClientError: The error code is 404 + + Returns: + bool: True if the object exists in S3, False otherwise + """ + parsed_s3_path = parse_s3_path(path) + try: + response = client.head_object(bucket=parsed_s3_path[0], key=parsed_s3_path[1]) + except exceptions.ClientError as e: + if e.response["Error"]["Code"] != "404": + raise e + return True + + +def _download_file(client: S3Client, s3_path: str, local_path: str) -> None: + """Download the object at the given S3 path to the given local file system path + + Args: + client (S3Client): The S3 client + + s3_path (str): The S3 source path + + local_path (str): The local destination path + """ + dirname = os.path.dirname(local_path) + os.makedirs(dirname, exist_ok=True) + parsed_s3_path = parse_s3_path(s3_path) + client.download_file(parsed_s3_path[0], parsed_s3_path[1], local_path) + + +def maybe_download_file(s3_path: str, local_path: str) -> None: + """Download the object at the given S3 path to the given local file system path + + In a distributed setting, downloading the S3 object proceeds in stages in order + to try to have the minimum number of processes download the object in order for + all the ranks to have access to the downloaded object. + + Args: + s3_path (str): The S3 source path + + local_path (str): The local destination path + """ + + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + local_rank = rank % torch.cuda.device_count() + else: + rank = 0 + local_rank = 0 + + s3_client = boto3.client("s3") + + if (not os.path.exists(local_path)) and (rank == 0): + _download_file(s3_client, s3_path, local_path) + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + # If the `local_path` is in a file system that is not + # shared across all the ranks, then we assume it's in the + # host file system and each host needs to download the file. + if (not os.path.exists(local_path)) and (local_rank == 0): + _download_file(s3_client, s3_path, local_path) + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + # If the `local_path` still does not exist, then we assume + # each rank is saving to a separate location. + if not os.path.exists(local_path): + _download_file(s3_client, s3_path, local_path) + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + assert os.path.exists(local_path) diff --git a/megatron/core/dist_checkpointing/__init__.py b/megatron/core/dist_checkpointing/__init__.py index df08d7eaba..30ce2c82de 100644 --- a/megatron/core/dist_checkpointing/__init__.py +++ b/megatron/core/dist_checkpointing/__init__.py @@ -1,11 +1,12 @@ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. 
from .core import check_is_distributed_checkpoint -from .mapping import LocalNonpersitentObject, ShardedTensor +from .mapping import LocalNonpersistentObject, ShardedObject, ShardedTensor from .serialization import ( load, load_common_state_dict, load_plain_tensors, load_tensors_metadata, + remove_sharded_tensors, save, ) diff --git a/megatron/core/dist_checkpointing/core.py b/megatron/core/dist_checkpointing/core.py index f20a0c3a2d..af6ebff6ec 100644 --- a/megatron/core/dist_checkpointing/core.py +++ b/megatron/core/dist_checkpointing/core.py @@ -1,5 +1,7 @@ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. +""" Module for managing distributed checkpoints metadata. """ + import json from dataclasses import asdict, dataclass from pathlib import Path @@ -9,12 +11,21 @@ class CheckpointingException(Exception): + """Base checkpointing related exception""" + pass @dataclass class CheckpointingConfig: - """ Documents backends used in the checkpoint. """ + """Documents backends used in the checkpoint. + + Checkpoint config keeps track of formats used for storing the sharded tensors + (sharded_backend) and other objects (common_backend). + + Note that versioning is not for the checkpoint content (which is application specific), + but for the checkpoint format itself. + """ sharded_backend: str sharded_backend_version: int = 1 @@ -23,10 +34,26 @@ class CheckpointingConfig: def check_is_distributed_checkpoint(checkpoint_dir): + """Checks if `metadata.json` exists in the checkpoint and is a valid config. + + Args: + checkpoint_dir: checkpoint directory + + Returns: + bool: True if `metadata.json` exists in the checkpoint and is a valid config. + """ return maybe_load_config(checkpoint_dir) is not None def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]: + """Returns checkpoint config if `checkpoint_dir` is a distributed checkpoint and None otherwise + + Args: + checkpoint_dir: checkpoint directory + + Returns: + CheckpointingConfig (optional): None if checkpoint is not a valid distributed checkpoint + """ config_path = Path(checkpoint_dir, CONFIG_FNAME) if not config_path.exists(): return None @@ -36,6 +63,15 @@ def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]: def save_config(config: CheckpointingConfig, checkpoint_dir: str): + """Save given config to checkpoint directory. + + Args: + config: checkpoint config + checkpoint_dir: checkpoint directory + + Returns: + None + """ config_path = Path(checkpoint_dir, CONFIG_FNAME) with config_path.open('w') as f: json.dump(asdict(config), f) diff --git a/megatron/core/dist_checkpointing/dict_utils.py b/megatron/core/dist_checkpointing/dict_utils.py index c6baf4f11b..cd46134ea0 100644 --- a/megatron/core/dist_checkpointing/dict_utils.py +++ b/megatron/core/dist_checkpointing/dict_utils.py @@ -1,23 +1,45 @@ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. -""" Utilities for operating with dicts and lists. """ +""" Utilities for operating with dicts and lists. + +All functions in this module handle nesting of dicts and lists. +Other objects (e.g. tuples) are treated as atomic leaf types that cannot be traversed. 
+""" from collections import defaultdict -from typing import Any, Callable, Iterable, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Tuple, TypeVar, Union +import numpy as np import torch +U, V = TypeVar("U"), TypeVar("V") + def extract_matching_values( - x: Union[dict, list], predicate: Callable + x: Union[dict, list], predicate: Callable[[Any], bool], return_lists_as_dicts: bool = False ) -> Tuple[Union[dict, list], Union[dict, list]]: - """ Return matching and nonmatching values. Keeps hierarchy. """ + """Return matching and nonmatching values. Keeps hierarchy. + + Args: + x (Union[dict, list]) : state dict to process. Top-level argument must be a dict or list + predicate (object -> bool): determines matching values + return_lists_as_dicts (bool): if True, matching lists will be turned + into dicts, with keys indicating the indices of original elements. + Useful for reconstructing the original hierarchy. + """ + + def _set_elem(target, k, v): + if return_lists_as_dicts: + target[k] = v + else: + target.append(v) + if isinstance(x, dict): matching_vals = {} nonmatching_vals = {} for k, v in x.items(): if isinstance(v, (list, dict)): - match, nonmatch = extract_matching_values(v, predicate) + match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts) if match: matching_vals[k] = match if nonmatch or not v: @@ -26,25 +48,40 @@ def extract_matching_values( matching_vals[k] = v else: nonmatching_vals[k] = v - else: - assert isinstance(x, list) - matching_vals = [] - nonmatching_vals = [] - for v in x: + elif isinstance(x, list): # type: ignore + matching_vals = {} if return_lists_as_dicts else [] + nonmatching_vals = {} if return_lists_as_dicts else [] + for ind, v in enumerate(x): if isinstance(v, (list, dict)) and v: - match, nonmatch = extract_matching_values(v, predicate) + match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts) if match: - matching_vals.append(match) + _set_elem(matching_vals, ind, match) if nonmatch or not v: - nonmatching_vals.append(nonmatch) - elif predicate(v): - matching_vals.append(v) + _set_elem(nonmatching_vals, ind, nonmatch) else: - nonmatching_vals.append(v) + target = matching_vals if predicate(v) else nonmatching_vals + _set_elem(target, ind, v) + else: + raise ValueError(f'Unexpected top-level object type: {type(x)}') return matching_vals, nonmatching_vals def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]: + """Recursive diff of dicts. + + Args: + x1 (object): left dict + x2 (object): right dict + prefix (tuple): tracks recursive calls. Used for reporting differing keys. + + Returns: + Tuple[list, list, list]: tuple of: + - only_left: Prefixes present only in left dict + - only_right: Prefixes present only in right dict + - mismatch: values present in both dicts but not equal across dicts. + For tensors equality of all elems is checked. + Each element is a tuple (prefix, type of left value, type of right value). 
+ """ mismatch = [] if isinstance(x1, dict) and isinstance(x2, dict): only_left = [prefix + (k,) for k in x1.keys() - x2.keys()] @@ -54,7 +91,8 @@ def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]: only_left.extend(_left) only_right.extend(_right) mismatch.extend(_mismatch) - elif isinstance(x1, list) and isinstance(x2, list): + elif isinstance(x1, list) or isinstance(x1, tuple) or isinstance(x1, np.ndarray): + assert type(x1) == type(x2) only_left = list(range(len(x1) - 1, len(x2) - 1, -1)) only_right = list(range(len(x1) - 1, len(x2) - 1, -1)) for i, (v1, v2) in enumerate(zip(x1, x2)): @@ -66,7 +104,17 @@ def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]: only_left = [] only_right = [] if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor): - _is_mismatch = not torch.all(x1 == x2) + if x1.device != x2.device: + _is_mismatch = not torch.all(x1.cpu() == x2.cpu()) + else: + _is_mismatch = not torch.all(x1 == x2) + # TODO: change with concrete type that has both replica_id and data attrs + elif hasattr(x1, 'replica_id') and hasattr(x2, 'replica_id'): + assert type(x1) == type(x2) + only_left, only_right, mismatch = diff( + x1.data, x2.data, prefix + (type(x1),) + ) # type: ignore + _is_mismatch = False else: try: _is_mismatch = bool(x1 != x2) @@ -79,22 +127,8 @@ def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]: return only_left, only_right, mismatch -def inspect_keys_types(d: dict, prefix: Tuple = (), indent: int = 4): - print_indent = lambda: print(' ' * indent * len(prefix), end='') - for k, v in d.items(): - if isinstance(v, dict): - print_indent() - print(f'> {k}:') - inspect_keys_types(v, prefix + (k,), indent) - else: - print_indent() - if isinstance(v, torch.Tensor): - print(f'> {k}: {type(v)} of shape {v.shape}') - else: - print(f'> {k}: {type(v)}') - - def inspect_types(x: Any, prefix: Tuple = (), indent: int = 4): + """Helper to print types of (nested) dict values.""" print_indent = lambda: print(' ' * indent * len(prefix), end='') if isinstance(x, dict): print() @@ -122,6 +156,7 @@ def inspect_types(x: Any, prefix: Tuple = (), indent: int = 4): def nested_values(x: Union[dict, list]): + """Returns iterator over (nested) values of a given dict or list.""" x_iter = x.values() if isinstance(x, dict) else x for v in x_iter: if isinstance(v, (dict, list)): @@ -131,6 +166,7 @@ def nested_values(x: Union[dict, list]): def nested_items_iter(x: Union[dict, list]): + """Returns iterator over (nested) tuples (container, key, value) of a given dict or list.""" x_iter = x.items() if isinstance(x, dict) else enumerate(x) for k, v in x_iter: if isinstance(v, (dict, list)): @@ -140,16 +176,19 @@ def nested_items_iter(x: Union[dict, list]): def dict_map(f: Callable, d: dict): + """`map` equivalent for dicts.""" for sub_d, k, v in nested_items_iter(d): sub_d[k] = f(v) def dict_map_with_key(f: Callable, d: dict): + """`map` equivalent for dicts with a function that accepts tuple (key, value).""" for sub_d, k, v in nested_items_iter(d): sub_d[k] = f(k, v) -def dict_list_map_inplace(f: Callable, x: Union[dict, list]): +def dict_list_map_inplace(f: Callable[[U], V], x: Union[Dict, List, U]): + """Maps dicts and lists *in-place* with a given function.""" if isinstance(x, dict): for k, v in x.items(): x[k] = dict_list_map_inplace(f, v) @@ -160,7 +199,8 @@ def dict_list_map_inplace(f: Callable, x: Union[dict, list]): return x -def dict_list_map_outplace(f: Callable, x: Union[dict, list]): +def dict_list_map_outplace(f: 
Callable[[U], V], x: Union[Dict, List, U]) -> Union[Dict, List, V]: + """Maps dicts and lists *out-of-place* with a given function.""" if isinstance(x, dict): return {k: dict_list_map_outplace(f, v) for k, v in x.items()} elif isinstance(x, list): @@ -169,20 +209,27 @@ def dict_list_map_outplace(f: Callable, x: Union[dict, list]): return f(x) -def merge(x1: dict, x2: dict): +def merge(x1: Union[dict, list], x2: Union[dict, list], key: Tuple[Union[str, int], ...] = ()): + """Merges dicts and lists recursively.""" if isinstance(x1, dict) and isinstance(x2, dict): for k, v2 in x2.items(): if k not in x1: x1[k] = v2 else: - x1[k] = merge(x1[k], v2) + x1[k] = merge(x1[k], v2, key=key + (k,)) elif isinstance(x1, list) and isinstance(x2, list): if len(x1) != len(x2): - raise ValueError('Cannot merge two lists with different lengths') + raise ValueError( + f'Cannot merge two lists with different lengths ({len(x1)} and {len(x2)}, ' + f'encountered at level {key})' + ) for i, v2 in enumerate(x2): - x1[i] = merge(x1[i], v2) + x1[i] = merge(x1[i], v2, key=key + (i,)) else: - raise ValueError(f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}`') + raise ValueError( + f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}` ' + f'(at level {key})' + ) return x1 @@ -192,6 +239,7 @@ def map_reduce( value_fn: Callable = lambda x: x, reduce_fn: Callable = lambda x: x, ) -> dict: + """Simple map-reduce implementation following `more_itertools.map_reduce` interface.""" res = defaultdict(list) for x in xs: res[key_fn(x)].append(value_fn(x)) diff --git a/megatron/core/dist_checkpointing/exchange_utils.py b/megatron/core/dist_checkpointing/exchange_utils.py new file mode 100644 index 0000000000..ea2cf6cc8a --- /dev/null +++ b/megatron/core/dist_checkpointing/exchange_utils.py @@ -0,0 +1,544 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +"""Utilities for exchanging data between ranks.""" + +import logging +from collections import defaultdict +from functools import reduce +from itertools import zip_longest +from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar, cast + +import numpy as np +import torch + +from .core import CheckpointingException +from .dict_utils import nested_values +from .mapping import ShardedStateDict, ShardedTensor, is_main_replica +from .utils import _sharded_tensor_shard_id, _ShardId, debug_time + +# TODO: remove TE references once the TE bug is fixed +# Check if Transformer Engine has Float8Tensor class +HAVE_TE_FLOAT8TENSOR = False +try: + from transformer_engine.pytorch.float8_tensor import Float8Tensor + + HAVE_TE_FLOAT8TENSOR = True +except (ImportError, ModuleNotFoundError): + # Float8Tensor not found + pass + + +def is_float8tensor(tensor: torch.Tensor) -> bool: + """Check if a tensor is a Transformer Engine Float8Tensor""" + return HAVE_TE_FLOAT8TENSOR and isinstance(tensor, Float8Tensor) + + +logger = logging.getLogger(__name__) + + +class ShardDistribution(NamedTuple): + """Represents a distribution of ShardedTensors. + + Given distribution is valid only for a specific parallelization group, + which is implicit here (not referenced by this class). 
+ + Args: + main_rank_for_shard (Dict[_ShardId, int]): specifies which rank should hold + the main replica for a given shard + shards_in_this_group (Set[_ShardId]): which shards have a main replica + in this parallelization group + shard_to_metadata (Dict[_ShardId, ShardedTensor]): maps ShardedTensor + identifier to the original ShardedTensor + all_ranks_for_shard (Dict[_ShardId, List[int]]): specifies which ranks + need a given shard in a given parallelization group + """ + + main_rank_for_shard: Dict[_ShardId, int] + shards_in_this_group: Set[_ShardId] + shard_to_metadata: Dict[_ShardId, ShardedTensor] + all_ranks_for_shard: Dict[_ShardId, List[int]] + + +def _shard_size(sh_ten: ShardedTensor): + """Returns size in bytes of a given sharded tensor.""" + if sh_ten.flattened_range is None: + numel = np.product(sh_ten.local_shape) + else: + numel = sh_ten.flattened_range.stop - sh_ten.flattened_range.start + return numel * torch._utils._element_size(sh_ten.dtype) + + +def _get_empty_tensor_for_exchange( + shard_id: _ShardId, + needed_shards: Dict[_ShardId, ShardedTensor], + unneeded_shards: Dict[_ShardId, ShardedTensor], + loaded_tensors: Dict[_ShardId, torch.Tensor], +) -> Tuple[torch.Tensor, Optional[torch.device]]: + """Determines the empty tensor to use for exchange. + + If shard_id is needed by this rank, it will be in the `unloaded_shards`. + Otherwise, the metadata for this tensor can be found in `shard_to_metadata` + + Args: + shard_id (_ShardId): shard_id that will be exchanged + needed_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids + to metadata for shards needed by this rank + unneeded_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids + to metadata for shards that can be discarded after exchange + loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping where useful tensors + are placed in + + Returns: + Tuple[torch.Tensor, Optional[torch.device]]: empty CUDA tensor to be exchanged, + and the device of the original state dict tensor (if there was any) + """ + local_unloaded_sh_ten = needed_shards.get(shard_id) + if local_unloaded_sh_ten is None: + orig_device = None # this tensor will be discarded anyway + sh_ten = unneeded_shards[shard_id] + if sh_ten.data is None: + sh_ten.init_data('cuda') + tensor = sh_ten.data + sh_ten.data = None # won't be used. free memory + else: + tensor = sh_ten.data + if tensor.device.type == 'cpu': + tensor = torch.empty_like(tensor, device='cuda') + else: + local_unloaded_sh_ten.init_data('cuda') + orig_device = local_unloaded_sh_ten.data.device + tensor = local_unloaded_sh_ten.data + if tensor.device.type == 'cpu': + tensor = torch.empty_like(tensor, device='cuda') + loaded_tensors[shard_id] = tensor + return tensor, orig_device + + +T = TypeVar('T') + + +def distribute_shards_to_ranks( + shard_to_ranks: Dict[T, List[int]], shard_to_size: Dict[T, int], num_ranks: int +) -> Dict[T, int]: + """Computes uniform distribution of workload across ranks, based on sizes. + + Currently, the assignment is greedy, based on: + 1. Firstly, the coverage of each shard + (how many ranks the shard is available on; lower coverage is assigned first) + 2. Secondly, the size of each shard (larger size is assigned first) + 3. Finally, shard id for differentiation. + + Third step is added because we rely on the fact that + the assignment is deterministic on all ranks. 
+ + Args: + shard_to_ranks (Dict[T, List[int]]): mapping of rank access to shards + shard_to_size (Dict[T, int]): sizes of each shard + num_ranks (int): number of ranks in the parallelization group + + Returns (Dict[T, int]): assignment of shard to rank (which rank should do the work + to achieve maximal uniformity) + """ + shard_to_ranks = {k: tuple(v) for k, v in shard_to_ranks.items()} + shard_to_saving_rank = {} + rank_sizes = [(0, rank) for rank in range(num_ranks)] + + # start from tensors of lowest coverage, then go by tensor size from largest (hence minus size) + for shard_id, shard_ranks in sorted( + shard_to_ranks.items(), + key=lambda sh_id_ranks: ( + len(sh_id_ranks[1]), + -shard_to_size[sh_id_ranks[0]], + sh_id_ranks[0], + ), + ): + # assign greedily to the least occupied rank + size, rank = min((size, rank) for size, rank in rank_sizes if rank in shard_ranks) + + shard_to_saving_rank[shard_id] = rank + rank_sizes[rank] = (size + shard_to_size[shard_id], rank) + + logger.debug(f'distribute_shards_to_ranks distribution: {rank_sizes}') + + return shard_to_saving_rank + + +def determine_main_replica_uniform_distribution( + sharded_state_dict: ShardedStateDict, + parallelization_group: torch.distributed.ProcessGroup, + ignore_groups: bool = False, +) -> Optional[ShardDistribution]: + """Computes the save distribution. + + Should be used in conjunction with `distribute_main_replicas_with_precomputed_distribution` + which applies the computed save distribution. + + We rely on the fact that the assignment algorithm is deterministic on all ranks, + so there is no extra communication needed after metadata exchange. + + Args: + sharded_state_dict (ShardedStateDict): state dict to compute the distribution of + parallelization_group (ProcessGroup): distribution will be computed + within this process group + ignore_groups (bool, optional): whether the distribution defines groups. + This option is primarily used during loading, as it ensures that all replicas, + including non-main ones, are loaded by this parallelization group + Defaults to False. + + Returns (ShardDistribution, optional): distribution that can be used to apply the + parallelization. 
Returns None if the process_group is trivial (1 rank) + + """ + group_size = torch.distributed.get_world_size(group=parallelization_group) + if group_size <= 1: + return + local_shards = list( + sh_base + for sh_base in nested_values(sharded_state_dict) + if isinstance(sh_base, ShardedTensor) + ) + local_shards_no_data = [ten.without_data() for ten in local_shards] + + all_shards = [None] * torch.distributed.get_world_size(group=parallelization_group) + torch.distributed.all_gather_object( + all_shards, local_shards_no_data, group=parallelization_group + ) + + shard_to_ranks = defaultdict(list) + shard_to_size = {} + shard_to_metadata = {} + shards_in_this_parallelization_group: Set[_ShardId] = set() + for rank, rank_shards in enumerate(all_shards): + for sh_ten in rank_shards: + shard_id = _sharded_tensor_shard_id(sh_ten) + shard_to_ranks[shard_id].append(rank) + if shard_id not in shard_to_size: + shard_to_size[shard_id] = _shard_size(sh_ten) + shard_to_metadata[shard_id] = sh_ten + if is_main_replica(sh_ten.replica_id) or ignore_groups: + shards_in_this_parallelization_group.add(shard_id) + + shard_to_ranks = { + k: v for k, v in shard_to_ranks.items() if k in shards_in_this_parallelization_group + } + + shard_to_saving_rank = distribute_shards_to_ranks( + shard_to_ranks, shard_to_size, len(all_shards) + ) + + return ShardDistribution( + shard_to_saving_rank, + shards_in_this_parallelization_group, + shard_to_metadata, + shard_to_ranks, + ) + + +@torch.no_grad() +@debug_time(f"exchange_loaded_tensors_gather_rounds", logger) +def exchange_loaded_tensors_gather_rounds( + loaded_tensors: Dict[_ShardId, torch.Tensor], + unloaded_shards: Dict[_ShardId, ShardedTensor], + shard_distribution: ShardDistribution = None, + parallelization_group: Optional[torch.distributed.ProcessGroup] = None, +) -> Dict[_ShardId, torch.Tensor]: + """Exchange the tensors loaded by different ranks with several all_gather calls. + + Groups tensors by dtype, divide tensors that will be exchanged into rounds + and execute all_gather for tensors from each round. + + Note: the loading is distributed across ranks based on total loaded size + in bytes, so there is no guarantee that number of rounds needed for each + rank will be similar, which might result in a lot of almost empty + all_gathers. The solution would be to group all tensors into a one + bytes tensor and do a single all_gather (with similarly sized messages). + + Args: + loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor + shard ids to tensors already loaded by this rank. + unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor + shard ids to ShardedTensors that aren't loaded yet. + shard_distribution (ShardDistribution): distribution of all shards + parallelization_group (ProcessGroup, optional): process group used for load + distribution. Tensors will be exchanged within this group + + Returns: + Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors + needed by this rank to load a given state dict. 
Includes + previously loaded tensors (from `loaded_tensors` input) + """ + main_rank_for_shard, _, shard_to_metadata, all_ranks_for_shard = shard_distribution + local_rank = torch.distributed.get_rank(group=parallelization_group) + + all_loaded_tensors = dict(loaded_tensors) + + # Group by dtype so that we all_gather tensors of the same dtype + for dtype in sorted(set(map(lambda sh_ten: sh_ten.dtype, shard_to_metadata.values())), key=str): + + with debug_time(f"dtype_{dtype}"): + # shards_by_rank maps rank to tensors loaded by this rank + shards_by_rank: List[List[torch.Tensor]] = [ + [] for _ in range(torch.distributed.get_world_size(group=parallelization_group)) + ] + for shard_id, rank in main_rank_for_shard.items(): + if len(all_ranks_for_shard[shard_id]) == 1: + assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], ( + f'When there is only 1 ranks that needs a given shard,' + f' it should be the loading rank.' + f' Got: needs [{all_ranks_for_shard[shard_id][0]}]' + f' vs loads [{main_rank_for_shard[shard_id]}]' + ) + # Skipping the exchange since only the loading rank needs this tensor + # TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1` + # case, e.g. P2P exchange. Currently handling this case saves most of the + # work though. + continue + if shard_to_metadata[shard_id].dtype == dtype: + shards_by_rank[rank].append(shard_id) + + # Transpose `shards_by_rank` to form exchange rounds + shards_by_round = zip_longest(*shards_by_rank, fillvalue=None) + for round_idx, round_shard_ids in enumerate(shards_by_round): + round_tensors = [] + orig_devices = {} + for rank, shard_id in enumerate(round_shard_ids): + if shard_id is None: + # if no more useful data, the given rank will exchange empty tensor + local_ten = torch.empty(0, dtype=dtype, device='cuda') + orig_device = None + else: + assert isinstance(shard_id, tuple), type(shard_id) + if rank == local_rank: + assert shard_id in all_loaded_tensors, ( + shard_id, + all_loaded_tensors.keys(), + ) + orig_device = all_loaded_tensors[shard_id] + all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda() + local_ten = all_loaded_tensors[shard_id] + else: + local_ten, orig_device = _get_empty_tensor_for_exchange( + shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors + ) + # Because of a TE bug, we have to exchange a nominal dtype instead of FP8 + # It's ok to keep the nominal dtype after exchange, because TE will handle + # this during state dict load. 
+ # TODO: remove it once the bug is fixed + if is_float8tensor(local_ten): + local_ten = local_ten.from_float8() + all_loaded_tensors[shard_id] = local_ten + + round_tensors.append(local_ten) + if orig_device is not None: + orig_devices[shard_id] = orig_device + + torch.distributed.all_gather( + list(round_tensors), + round_tensors[local_rank], + group=parallelization_group, + async_op=False, + ) + + # Move tensors back to CPU if originally was on CPU + for shard_id, orig_device in orig_devices.items(): + all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].to(orig_device) + + del round_tensors # remove tensor references + + return all_loaded_tensors + + +def exchange_loaded_tensors_gather_object( + loaded_tensors: Dict[_ShardId, torch.Tensor], + unloaded_shards: Dict[_ShardId, ShardedTensor], + shard_distribution: ShardDistribution, + parallelization_group: Optional[torch.distributed.ProcessGroup] = None, +) -> Dict[_ShardId, torch.Tensor]: + """Exchange the tensors loaded by different ranks with a simple all_gather_object call. + + This version can be used for debugging purposes do to its simplistic + implementation. Shouldn't be used if performance is important. + + Args: + loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor + shard ids to tensors already loaded by this rank. + unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor + shard ids to ShardedTensors that aren't loaded yet. + shard_distribution (ShardDistribution): distribution of all shards + parallelization_group (ProcessGroup, optional): process group used for load + distribution. Tensors will be exchanged within this group + + Returns: + Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors + needed by this rank to load a given state dict. Includes + previously loaded tensors (from `loaded_tensors` input) + + """ + all_loaded_tensors_list = [None] * torch.distributed.get_world_size(group=parallelization_group) + torch.distributed.all_gather_object( + all_loaded_tensors_list, loaded_tensors, group=parallelization_group + ) + all_loaded_tensors_list = cast(List[Dict[_ShardId, torch.Tensor]], all_loaded_tensors_list) + all_loaded_tensors = reduce(lambda x, y: {**x, **y}, all_loaded_tensors_list) + + # Error checks + if len(all_loaded_tensors) != sum(map(len, all_loaded_tensors_list)): + err_msg = 'Duplicate shard ids loaded by different ranks' + if torch.distributed.get_rank() == 0: + logger.error( + f'{err_msg}. Shards ids by rank:' + f' {[lt.keys() for lt in all_loaded_tensors_list]}' + ) + raise CheckpointingException(err_msg) + + return all_loaded_tensors + + +def exchange_loaded_objects_gather_object( + loaded_objects: Dict[_ShardId, Any] +) -> Dict[_ShardId, Any]: + """Exchange the objects loaded by different ranks with a simple all_gather_object call. + + Args: + loaded_objects (Dict[_ShardId, Any]): mapping from shard ids to objects + already loaded by this rank. + + Returns: + Dict[_ShardId, Any]: dictionary mapping shard ids to objects needed by this rank to + load a given state dict. 
+ """ + all_loaded_objects_list = [None] * torch.distributed.get_world_size(group=None) + torch.distributed.all_gather_object(all_loaded_objects_list, loaded_objects, group=None) + all_loaded_objects_list = cast(List[Dict[_ShardId, Any]], all_loaded_objects_list) + all_loaded_objects = reduce(lambda x, y: {**x, **y}, all_loaded_objects_list) + + # Error checks + if len(all_loaded_objects) != sum(map(len, all_loaded_objects_list)): + err_msg = 'Duplicate shard ids loaded by different ranks' + if torch.distributed.get_rank() == 0: + logger.error( + f'{err_msg}. Shards ids by rank:' + f' {[lt.keys() for lt in all_loaded_objects_list]}' + ) + raise CheckpointingException(err_msg) + + return all_loaded_objects + + +@torch.no_grad() +@debug_time("exchange_loaded_tensors_broadcast", logger) +def exchange_loaded_tensors_broadcast( + loaded_tensors: Dict[_ShardId, torch.Tensor], + unloaded_shards: Dict[_ShardId, ShardedTensor], + shard_distribution: ShardDistribution, + parallelization_group: Optional[torch.distributed.ProcessGroup] = None, +) -> Dict[_ShardId, torch.Tensor]: + """Exchange the tensors loaded by different ranks by a series of broadcasts. + + For each rank for each loaded tensor do a broadcast to the whole group. + A reasonable tradeoff in terms of performance and simplicity. + + Args: + loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor + shard ids to tensors already loaded by this rank. + unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor + shard ids to ShardedTensors that aren't loaded yet. + shard_distribution (ShardDistribution): distribution of all shards + parallelization_group (ProcessGroup, optional): process group used for load + distribution. Tensors will be exchanged within this group + + Returns: + Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors + needed by this rank to load a given state dict. Includes + previously loaded tensors (from `loaded_tensors` input) + """ + main_rank_for_shard, _, shard_to_metadata, all_ranks_for_shard = shard_distribution + local_rank = torch.distributed.get_rank(group=parallelization_group) + + all_loaded_tensors = dict(loaded_tensors) + + for idx, (shard_id, rank) in enumerate(main_rank_for_shard.items()): + if len(all_ranks_for_shard[shard_id]) == 1: + assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], ( + f'When there is only 1 ranks that needs a given shard,' + f' it should be the loading rank.' + f'Got: needs [{all_ranks_for_shard[shard_id][0]}]' + f' vs loads [{main_rank_for_shard[shard_id]}]' + ) + # Skipping the exchange since only the loading rank needs this tensor + # TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1` case, + # e.g. P2P exchange. Currently handling this case saves most of the work though. + continue + if rank == local_rank: + assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys()) + orig_device = all_loaded_tensors[shard_id].device + local_ten = all_loaded_tensors[shard_id].cuda() + else: + local_ten, orig_device = _get_empty_tensor_for_exchange( + shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors + ) + + # Because of a TE bug, we have to exchange a nominal dtype instead of FP8 + # It's ok to keep the nominal dtype after exchange, because TE will handle + # this during state dict load. 
+ # TODO: remove it once the bug is fixed + if is_float8tensor(local_ten): + local_ten = local_ten.from_float8() + all_loaded_tensors[shard_id] = local_ten + + global_src_rank = ( + rank + if parallelization_group == None + else torch.distributed.get_global_rank(parallelization_group, rank) + ) + # We can do async_op=True only if there is no CPU-copy follow-up + torch.distributed.broadcast( + local_ten, + src=global_src_rank, + group=parallelization_group, + async_op=orig_device is None, + ) + # Move tensor back to CPU if originally was on CPU + if orig_device is not None: + all_loaded_tensors[shard_id] = local_ten.to(orig_device) + del local_ten + + return all_loaded_tensors + + +def exchange_by_distribution( + loaded_tensors: Dict[_ShardId, torch.Tensor], + unloaded_shards: Dict[_ShardId, ShardedTensor], + shard_distribution: ShardDistribution, + parallelization_group: Optional[torch.distributed.ProcessGroup] = None, + exchange_algo='broadcast', +) -> Dict[_ShardId, torch.Tensor]: + """Exchange tensors loaded by different ranks using the specified exchange_algo. + + Args: + loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor + shard ids to tensors already loaded by this rank. + unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor + shard ids to ShardedTensors that aren't loaded yet. + shard_distribution (ShardDistribution): distribution of all shards + parallelization_group (ProcessGroup, optional): process group used for load + distribution. Tensors will be exchanged within this group + exchange_algo (str): The algorithm used for performing exchanges. + Defaults to 'broadcast'. + + Returns: + Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors + needed by this rank to load a given state dict. Includes + previously loaded tensors (from `loaded_tensors` input) + """ + + assert shard_distribution is not None, 'Expecting distribution to perform exchange' + if exchange_algo == 'gather_object': + exchange_fn = exchange_loaded_tensors_gather_object + elif exchange_algo == 'gather_rounds': + exchange_fn = exchange_loaded_tensors_gather_rounds + elif exchange_algo == 'broadcast': + exchange_fn = exchange_loaded_tensors_broadcast + else: + raise NotImplementedError(f'Unrecognized gather algorithm: {exchange_algo}') + return exchange_fn(loaded_tensors, unloaded_shards, shard_distribution, parallelization_group) diff --git a/megatron/core/dist_checkpointing/mapping.py b/megatron/core/dist_checkpointing/mapping.py index bf24764e83..da5c3295b7 100644 --- a/megatron/core/dist_checkpointing/mapping.py +++ b/megatron/core/dist_checkpointing/mapping.py @@ -1,54 +1,83 @@ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. -""" Core library classes. """ +""" Core library classes for representing sharding of tensors and objects. -from dataclasses import dataclass, replace +The main expected usage is wrapping torch.Tensors in state dicts with +ShardedTensor class (mostly with the ShardedTensor.from_rank_offsets classmethod). 
+""" + +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass, field, replace from itertools import chain -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch from .core import CheckpointingException +from .dict_utils import dict_list_map_inplace + +logger = logging.getLogger(__name__) # These type definitions are just hints to differentiate a plain model state # dict (StateDict) from a state dict with tensors replaced with ShardedTensors # (ShardedStateDict). StateDict = Dict[str, Any] +CommonStateDict = Dict[str, Any] ShardedStateDict = Dict[str, Any] ReplicaId = Union[int, Tuple[int, ...]] +class ShardedBase(ABC): + """Base class for ShardedTensor and ShardedStateDict.""" + + key: str + data: object + replica_id: ReplicaId + + @abstractmethod + def validate_metadata_integrity(self): + """Codifies the constraints on metadata attributes.""" + + @abstractmethod + def without_data(self) -> 'ShardedBase': + """Returns a new ShardedBase instance with data=None.""" + raise NotImplementedError + + @dataclass -class ShardedTensor: +class ShardedTensor(ShardedBase): """Represents a mapping between a local tensor and a global tensor. Global tensor is assumed to consist of many local tensors distributed between different processes. - Attributes: + Args: key: unique identifier of a global tensor data: local tensor data. Can be None only for consistency validation dtype: tensor dtype local_shape: local tensor shape global_shape: global tensor shape - global_offset: offset of a local tensor in a global tensor, specified - in number of tensor elements + global_offset: offset of a local tensor in a global tensor, + specified in number of tensor elements axis_fragmentations: global tensor fragmentation of each axis - replica_id: indicates given local tensor's replication wrt. local - tensors in different processes - prepend_axis_num: number of axes prepended to the local tensor - to reflect global tensor shape. - The behavior is similar to unsqueezing the local tensor. - allow_shape_mismatch: if True, during loading, the global shape of a - stored tensor does not have to match the expected global shape. - Useful for representing tensors with flexible shape, e.g. padded. - flattened_range: specifies a slice that should be applied to a flattened - tensor with `local_shape` in order to get the tensor stored as `data` + replica_id: indicates given local tensor's replication wrt. + local tensors in different processes + prepend_axis_num: number of axes prepended to the local tensor to + reflect global tensor shape. The behavior is similar to + unsqueezing the local tensor. + allow_shape_mismatch: if True, during loading, the global shape of + a stored tensor does not have to match the expected global shape. + Useful for representing tensors with flexible shape, + e.g. padded. + flattened_range: specifies a slice that should be applied to a + flattened tensor with `local_shape` in order to get + the tensor stored as `data` """ key: str - data: Optional[torch.Tensor] + data: Optional[torch.Tensor] = field(repr=False) dtype: torch.dtype local_shape: Tuple[int, ...] global_shape: Tuple[int, ...] @@ -59,7 +88,69 @@ class ShardedTensor: allow_shape_mismatch: bool = False flattened_range: Optional[slice] = None + def __post_init__(self): + self.validate_metadata_integrity() + + def validate_metadata_integrity(self) -> None: + """Codifies the constraints on metadata attributes. 
+ + Meeting those constraints is guaranteed when instantiating a ShardedTensor + class with `from_rank_offsets` or `from_rank_offsets_flat` constructors. + + Returns: + None + """ + has_flattened_range = self.flattened_range is not None + if self.data is not None: + if self.data.dtype != self.dtype: + raise CheckpointingException( + f'Data dtype should match `dtype` attribute for {self}' + ) + if not has_flattened_range and self.data.shape != self.local_shape: + raise CheckpointingException( + f'Data shape should match `local_shape` attribute for {self}' + ) + if has_flattened_range: + if self.data.ndim != 1: + raise CheckpointingException(f'Data should be 1D for a flattened {self}') + real_data = self.data + try: + self.data = None + self.init_data(device='meta') + if self.data.shape != real_data.shape: + raise CheckpointingException( + f'Data shape {real_data.shape} doesnt match' + f' expected {self.data.shape} for {self}' + ) + finally: + self.data = real_data + + if len(self.global_shape) != len(self.global_offset): + raise CheckpointingException( + f'Global offset dimensions should be equal to global shape dimensions for {self}' + ) + if len(self.local_shape) + self.prepend_axis_num != len(self.global_shape): + raise CheckpointingException( + f'Local shape together with `prepend_axis_num` dimensions should be ' + f'equal to global shape dimensions for {self}' + ) + + for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape): + if off % sh != 0: + raise CheckpointingException( + f'Global offset ({off}) must be divisible by local shape ({sh}) for {self}.' + ) + + if has_flattened_range and self.flattened_range.step is not None: + raise CheckpointingException( + f'`step` argument in the flattened range of a ShardedTensor is not supported.' + ) + def global_slice(self) -> Tuple[Union[int, slice], ...]: + """ + Returns a tuple of int and slice objects representing a slice of the + global tensor that this ShardedTensor corresponds to. + """ assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num return tuple( chain( @@ -74,6 +165,10 @@ def global_slice(self) -> Tuple[Union[int, slice], ...]: ) def global_coordinates(self) -> Tuple[np.ndarray, ...]: + """ + Returns a tuple of np.ndarrays representing the coordinates of the global tensor + that this ShardedTensor corresponds to. + """ if self.flattened_range is None: raise CheckpointingException( f'`global_coordinates` is undefined for' @@ -92,6 +187,10 @@ def global_coordinates(self) -> Tuple[np.ndarray, ...]: return global_coords def local_coordinates(self) -> Tuple[np.ndarray, ...]: + """ + Returns a tuple of np.ndarrays representing the coordinates of the local tensor + that this ShardedTensor corresponds to. + """ if self.flattened_range is None: raise CheckpointingException( f'`local_coordinates` is undefined for' @@ -103,12 +202,28 @@ def local_coordinates(self) -> Tuple[np.ndarray, ...]: mask[self.flattened_range] = True return np.nonzero(mask.reshape(self.local_shape)) + def local_chunk_offset_in_global(self) -> Tuple[int, ...]: + """Offset of a local chunk in a global array of chunks. + + Returns: + Tuple[int, ...]: the offset of the whole local chunk in a global array of chunks. 
+ """ + assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num + chunk_offset = list(self.global_offset[: self.prepend_axis_num]) + for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape): + assert off % sh == 0, str(self) + chunk_offset.append(off // sh) + return tuple(chunk_offset) + def max_allowed_chunks(self) -> Tuple[int, ...]: + """ + Returns the maximum allowed chunks for this ShardedTensor. + """ chunks = [] for axis_sh, axis_fragm in zip(self.global_shape, self.axis_fragmentations): if not self.allow_shape_mismatch and axis_sh % axis_fragm != 0: raise CheckpointingException( - f'Axis shape ({axis_sh}) not divisible' f' by axis fragmentation ({axis_fragm}' + f'Axis shape ({axis_sh}) not divisible by axis fragmentation ({axis_fragm}' ) axis_chunk_size = axis_sh // axis_fragm chunks.append(axis_chunk_size) @@ -125,35 +240,35 @@ def from_rank_offsets( *rank_offsets: Tuple[int, int, int], replica_id: ReplicaId = 0, prepend_axis_num: int = 0, - allow_shape_mismatch: bool = False, + flattened_range: None = None, + **init_kwargs, ): """Allows to construct the ShardedTensor given offset specified in process ranks. - Arguments: - key: unique key - data: local tensor data - rank_offsets: each tuple (axis, axis_rank_offset, axis_fragm) - says that if global tensor is divided into `axis_fragm` - fragment along `axis` axis, then local tensor data - corresponds to the `axis_rank_offset` chunk. - replica_id: see ShardedTensor - prepend_axis_num: see ShardedTensor - allow_shape_mismatch: see ShardedTensor + + Args: + key (str): unique key + data (torch.Tensor): local tensor data + rank_offsets (Tuple[int, int, int]): each tuple + (axis, axis_rank_offset, axis_fragm) says that if + global tensor is divided into `axis_fragm` fragment along `axis` + axis, then local tensor data corresponds to the `axis_rank_offset` chunk. + replica_id (ReplicaId): see ShardedTensor + prepend_axis_num (int): see ShardedTensor + flattened_range (None): must be None when using this constructor + init_kwargs: passed to ShardedTensor.__init__ """ + if flattened_range is not None: + raise ValueError( + 'Cannot instantiate a flat ShardedTensor with `from_rank_offsets` method.' 
+ ' Use `from_rank_offsets_flat` instead' + ) global_offset = [0] * (data.ndim + prepend_axis_num) global_shape = ([1] * prepend_axis_num) + list(data.shape) axis_fragmentations = [1] * (data.ndim + prepend_axis_num) _seen_axis = set() for axis, axis_rank_offset, axis_fragm in rank_offsets: - assert axis >= 0 and axis_rank_offset >= 0 and axis_fragm >= 0, ( - axis, - axis_rank_offset, - axis_fragm, - ) - assert ( - axis_rank_offset < axis_fragm - ), 'Rank offset must be lower than axis fragmentation' - if axis in _seen_axis: - raise CheckpointingException('Duplicated axis specified') + if axis < 0 or axis_rank_offset < 0 or axis_fragm < 1 or axis_rank_offset >= axis_fragm: + raise CheckpointingException(f'Invalid rank offsets: {rank_offsets} for key {key}.') _seen_axis.add(axis) local_axis_shape = 1 if axis < prepend_axis_num else data.shape[axis - prepend_axis_num] @@ -171,23 +286,223 @@ def from_rank_offsets( tuple(axis_fragmentations), replica_id, prepend_axis_num, - allow_shape_mismatch, + flattened_range=flattened_range, + **init_kwargs, ) - def __str__(self): - return f'{self.__class__.__name__}(key=\'{self.key}\')' + @classmethod + def from_rank_offsets_flat( + cls, + key: str, + data: torch.Tensor, + non_flat_local_shape: Tuple[int, ...], + *args, + flattened_range: Optional[slice] = None, + **kwargs, + ): + """Allows to construct a *flattened* ShardedTensor given offset specified in process ranks. + + Args: + key (str): + data (torch.Tensor): this should be a flattened data tensor + non_flat_local_shape (Tuple[int, ...]): expected local shape of a non-flat chunk + *args: passed unchanged to the `from_rank_offsets` constructor + flattened_range (slice): see ShardedTensor. Defaults to None, but must be set to + a non-None slice. + **kwargs: + + Returns: + ShardedTensor: constructed ShardedTensor instance + """ + if flattened_range is None: + raise CheckpointingException( + 'Cannot instantiate a non-flat ShardedTensor with `from_rank_offsets_flat` method.' + ' Use `from_rank_offsets` instead' + ) + if data.ndim != 1: + raise CheckpointingException( + f'Flattened ShardedTensor requires 1D data, got shape: {data.shape}' + ) + if flattened_range.stop - flattened_range.start != data.numel(): + raise CheckpointingException( + f'Flattened ShardedTensor data length ({data.numel()}) must meet the ' + f'slice length: {flattened_range.stop - flattened_range.start}' + ) + non_flat_data_meta = torch.empty(*non_flat_local_shape, dtype=data.dtype, device='meta') + sh_ten = cls.from_rank_offsets(key, non_flat_data_meta, *args, **kwargs) + instance = replace(sh_ten, data=data, flattened_range=flattened_range) + instance.validate_metadata_integrity() + return instance -def is_main_replica(replica_id): + def init_data(self, device: Union[str, torch.device], init_fn=torch.empty): + """ + Initialize the tensor data of this ShardedTensor. + + Only called if `data` attribute is None. + + Args: + device (Union[str, torch.device]): device to place the tensor on + init_fn (Callable, optional): function to use to initialize the tensor. + Defaults to `torch.empty`. + """ + if self.data is not None: + return + self.data = init_fn(self.local_shape, dtype=self.dtype, device=device) + if self.flattened_range is not None: + self.data = self.data.flatten()[self.flattened_range.start : self.flattened_range.stop] + + def narrow(self, dim: int, start: int, length: int) -> List['ShardedTensor']: + """This is an analogue of torch.narrow for ShardedTensors. + + Narrowing assumes that we narrow a local tensor on each rank. 
+ This has consequences on local_shape, global_shape, global_offset, etc. + + Args: + dim (int): dimension to narrow. Doesn't include prepended axes. + start (int): start element + length (int): length of the slice + + Returns: + List[ShardedTensor]: narrowed ShardedTensors. For non-flat tensors, + the list will always have 1 element. For flat ShardedTensors the number of + elements varies depending on `dim` and on overlap, because flat + tensors must be contiguous. In particular the list can be empty. + """ + prepended_dim = dim + self.prepend_axis_num + local_length_along_dim = self.local_shape[dim] + + def _update_tuple(x, ind, val): + x = list(x) + x[ind] = val + return tuple(x) + + def _safe_div(x, y): + assert x % y == 0, (x, y) + return x // y + + # Decrease global shape and global offset by `length / local_length_along_dim` + assert ( + self.global_shape[prepended_dim] % local_length_along_dim == 0 + ), f'Only regular grid of local tensors is supported for narrowing, got: {self}' + assert ( + self.global_offset[prepended_dim] % local_length_along_dim == 0 + ), f'Only regular grid of local tensors is supported for narrowing, got: {self}' + global_shape = _update_tuple( + self.global_shape, + prepended_dim, + _safe_div(self.global_shape[prepended_dim] * length, local_length_along_dim), + ) + global_offset = _update_tuple( + self.global_offset, + prepended_dim, + _safe_div(self.global_offset[prepended_dim] * length, local_length_along_dim), + ) + + if self.flattened_range is None: + new_data = self.data.narrow(dim, start, length) + # always a single result tensor + return [ + replace( + self, + data=new_data, + local_shape=new_data.shape, + global_shape=global_shape, + global_offset=global_offset, + ) + ] + else: + if dim != 0: + raise CheckpointingException( + f'Narrowing along the first axis is supported for now only, got dim={dim}' + ) + + # If dim=0, we will always get 0 or 1 resulting tensor. + # If dim>1, in general there can be more result tensors (e.g. max 3 for dim=1) + + # For on original flat ShardedTensor of local shape [3, 4] and + # flattened_range=slice(5, 10), + # the X signs mark the actual (flat) data in `self.data` + # notice 12 (3*4) total "virtual" elements, out of which 5 is actual data. + # flat original: [.....XXXXX..] + + # If we narrow to start=1, length=1 in the original local shape dimensions, + # the overlapping flat slice would be: + # narrow to: [....XXXX....] + # flat overlap: [.....XXX....] 
+ + # Now `data` is flattened and sliced, so we must compute local_shape manually + local_shape = _update_tuple(self.local_shape, dim, length) + other_dims_volume = np.prod( + _update_tuple(local_shape, dim, 1) + ) # 4 in the example above + volume_before_split = other_dims_volume * start # 4 in the example above + volume_of_split = other_dims_volume * length # 4 in the example above + + flat_slice_start_shifted = ( + self.flattened_range.start - volume_before_split + ) # 5 - 4 = 1 in the example above + flat_slice_stop_shifted = ( + self.flattened_range.stop - volume_before_split + ) # 10 - 4 = 6 in the example above + + # Find an intersection of + # (flat_slice_start_shifted, flat_slice_stop_shifted) vs (0, volume_of_split) + + if flat_slice_stop_shifted <= 0 or flat_slice_start_shifted >= volume_of_split: + return [] # no intersection + + # new_flattened_range = slice(1, 4) in the example above + new_flattened_range = slice( + max(flat_slice_start_shifted, 0), min(flat_slice_stop_shifted, volume_of_split) + ) + # Apply the intersection to the flattened data tensor. + # Compute start and slice appropriate length + intersection_slice_start = ( + new_flattened_range.start - flat_slice_start_shifted + ) # 0 in the example above + new_data = self.data[ + intersection_slice_start : intersection_slice_start + + new_flattened_range.stop + - new_flattened_range.start + ] + + return [ + replace( + self, + data=new_data, + local_shape=local_shape, + global_shape=global_shape, + global_offset=global_offset, + flattened_range=new_flattened_range, + ) + ] + + +def is_main_replica(replica_id: ReplicaId): + """Checks if given `replica_id` is considered as main. + + "Main" replica is: + - integer 0 + - or an iterable with all 0 elements + + It is the application responsibility to set correct replicas for sharded tensors. + + Args: + replica_id (Union[int, Tuple[int, ...]]): replica id + + Returns: + (bool): True for a "main" replica + """ if isinstance(replica_id, int): return replica_id == 0 return all(r == 0 for r in replica_id) -class LocalNonpersitentObject: +class LocalNonpersistentObject: """Object that should not be stored in a checkpoint, but restored locally. - Wrapping any object inside the state dict with LocalNonpersitentObject + Wrapping any object inside the state dict with LocalNonpersistentObject will result in: - during saving, this object will *not* be stored in the checkpoint - during loading, a local version of this object will be placed in a state dict @@ -197,11 +512,12 @@ def __init__(self, obj): self.obj = obj def unwrap(self): + """Returns the original object.""" return self.obj @dataclass -class ShardedObject: +class ShardedObject(ShardedBase): """Represents a mapping between a local object and a global object. Global object is assumed to consist of many local objects distributed @@ -211,14 +527,12 @@ class ShardedObject: sharding. Conceptually, ShardedObject is a fully-sharded ShardedTensor with atomic arbitrary typed elements. - Attributes: + Args: key: unique identifier of a global tensor data: local object data. Can be None only for consistency validation global_shape: global object shape - global_offset: offset of a local object in a global object, specified - in number of shards - replica_id: indicates local object replication wrt. local - objects in different processes + global_offset: offset of a local object in a global object, specified in number of shards + replica_id: indicates local object replication wrt. 
local objects in different processes """ key: str @@ -227,12 +541,179 @@ class ShardedObject: global_offset: Tuple[int, ...] replica_id: ReplicaId = 0 + def __post_init__(self): + self.validate_metadata_integrity() + + def validate_metadata_integrity(self): + if len(self.global_shape) != len(self.global_offset): + raise CheckpointingException( + f'Global offset dimensions should be equal to global shape dimensions for {self}' + ) + def without_data(self): return replace(self, data=None) @property def unique_key(self): - return f'{self.key}/shard_{".".join(map(str, self.global_offset))}_{".".join(map(str, self.global_shape))}' + """returns a unique key for this object""" + return ( + f'{self.key}/shard_' + f'{".".join(map(str, self.global_offset))}_' + f'{".".join(map(str, self.global_shape))}' + ) def __str__(self): return f'{self.__class__.__name__}(key=\'{self.key}\')' + + @classmethod + def empty_from_unique_key(cls, unique_key, replica_id: ReplicaId = 0) -> 'ShardedObject': + """Instantiates a ShardedObject from a unique key. + + Args: + unique_key: a string of the form + /shard__ + replica_id: indicates local object replication wrt. + local objects in different processes + + Returns: + a ShardedObject with data=None + """ + key, shard_key = unique_key.split('/') + shard_str, offset, shape = shard_key.split('_') + assert shard_str == 'shard' + offset = tuple(map(int, offset.split('.'))) + shape = tuple(map(int, shape.split('.'))) + if len(shape) + 1 == len(offset): + # This is a backward-compatible fix. We don't know the last + # element of global shape so set it to -1. + shape += (-1,) + return cls(key, None, shape, offset, replica_id) + + +FactoryBuildFn = Callable[[str, torch.Tensor, ReplicaId, Optional[slice]], ShardedStateDict] +FactoryMergeFn = Callable[[StateDict], torch.Tensor] + + +@dataclass +class ShardedTensorFactory(ShardedBase): + """Allows to apply transformations to tensors before/after serialization. + + The essence of those transformations is that they can be applied to + optimizer states the same way they are applied to the model params. + The ultimate state dict with sharded tensors must depend functionally on + `build_fn` arguments (key, data, replica_id, flattened_range), + which will be provided by the optimizer. + + Builder creates a sub-state-dict out of a tensor before saving, and merger + merges the corresponding state dict after loading. + + Args: + key (str): unique identifier of the factory + data (torch.Tensor): original model parameter that will be further + transformed by this factory + build_fn (callable): function that transforms the original tensor + to a sharded state dict + merge_fn (callable): function that transforms loaded subtree back + into a single tensor (inverse of `build_fn`) + replica_id (ReplicaId): indicates factory replication wrt. 
+ factories in different processes + flattened_range (slice, optional): indicates additional flattening + applied to the ShardedTensors produced by the factory + """ + + key: str + data: torch.Tensor + build_fn: FactoryBuildFn + merge_fn: FactoryMergeFn + replica_id: ReplicaId = 0 + flattened_range: Optional[slice] = None + + def build(self): + """Builds a ShardedStateDict from the original tensor""" + return self.build_fn(self.key, self.data, self.replica_id, self.flattened_range) + + def validate_metadata_integrity(self): + """No reasonable checks can be applied""" + pass + + def without_data(self): + return replace(self, data=None) + + +def apply_factories(sharded_state_dict: ShardedStateDict): + """Turn ShardedTensorFactories into ShardedTensors *in-place*. + + Args: + sharded_state_dict (ShardedStateDict): state dict possibly + containing ShardedTensorFactory objects + + Returns: + None: state dict is modified in place + """ + + def apply(x): + if isinstance(x, ShardedTensorFactory): + x = x.build() + return x + + dict_list_map_inplace(apply, sharded_state_dict) + + +def apply_factory_merges( + x1: StateDict, x2: ShardedStateDict, key: Tuple[str, ...] = () +) -> StateDict: + """Apply merges defined by ShardedTensorFactories *in-place*. + + Args: + x1 (StateDict): state dict loaded from the checkpoint + x2 (ShardedStateDict): subset of `x1` (in terms of dict keys) + with ShardedTensorFactory + as (possibly nested) values that define how to + merge objects from the `x1` state dict + key (Tuple[str, ...]): current key in a recursive call. + Used only for reporting meaningful errors + + Returns: + StateDict: `x1` modified in-place + """ + if isinstance(x2, ShardedTensorFactory): + return x2.merge_fn(x1) + + # There rest is almost the same as the `merge` function from `dict_utils` + if isinstance(x1, dict) and isinstance(x2, dict): + for k, v2 in x2.items(): + if k not in x1: + raise ValueError( + f'Different dict keys encountered in `apply_factory_merges` ' + f'({x1.keys()} vs {x2.keys()})' + ) + else: + x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,)) + elif isinstance(x1, list) and isinstance(x2, list): + if len(x1) != len(x2): + err_msg = ( + f'Cannot merge two lists with different lengths ' + f'({len(x1)} and {len(x2)}, encountered at key {key})' + ) + logger.error(err_msg + f'\nx1: {x1}\nx2: {x2}') + raise ValueError(err_msg) + for i, v2 in enumerate(x2): + x1[i] = apply_factory_merges(x1[i], v2, key=key + (i,)) + elif isinstance(x1, list) and isinstance(x2, dict): + for k, v2 in x2.items(): + if not isinstance(k, int): + raise ValueError( + f'Invalid dict key {k} non-integer type encountered ' + f'in a list-dict merge at level {key}' + ) + if k >= len(x1): + raise ValueError( + f'Dict key {k} out of bound for list of length' + f'{len(x1)} (encountered at level {key})' + ) + x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,)) + else: + raise ValueError( + f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2} (at key {key})`' + ) + return x1 diff --git a/megatron/core/dist_checkpointing/optimizer.py b/megatron/core/dist_checkpointing/optimizer.py index 0d76676417..b3fcc7c645 100644 --- a/megatron/core/dist_checkpointing/optimizer.py +++ b/megatron/core/dist_checkpointing/optimizer.py @@ -1,25 +1,36 @@ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. -""" Optimizer related helpers. """ +""" Helpers for defining sharding for optimizer states based on existing sharding +for model parameters. 
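A rough usage sketch (editor's addition, not part of the original docstring;
`model` and `optimizer` are hypothetical objects, and the helpers are the ones
defined below in this module):

    model_sharded_sd = model.sharded_state_dict()
    optim_sd = optimizer.state_dict()
    id_to_sharded_param = get_param_id_to_sharded_param_map(
        model_sharded_sd,
        (p for group in optimizer.param_groups for p in group['params']),
    )
    optim_state_to_sharding_state(optim_sd, id_to_sharded_param)
    # optim_sd['state'] now holds ShardedTensors (e.g. exp_avg, exp_avg_sq) and
    # each param group's 'params' entry is wrapped in LocalNonpersistentObject.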
+""" import logging from copy import deepcopy from dataclasses import replace -from itertools import chain -from typing import Dict, Iterable, List, Tuple +from typing import Dict, Iterable, Tuple, Union logger = logging.getLogger(__name__) import torch +from megatron.core.utils import to_local_if_dtensor + from .dict_utils import nested_values -from .mapping import LocalNonpersitentObject, ShardedStateDict, ShardedTensor, StateDict -from .utils import extract_sharded_tensors +from .mapping import ( + LocalNonpersistentObject, + ShardedStateDict, + ShardedTensor, + ShardedTensorFactory, + StateDict, +) +from .utils import extract_sharded_tensors_and_factories def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]: + """Generate mapping from optimizer param to optimizer state id.""" param_mappings = {} for i, param in enumerate(optim_params_iter): + param = to_local_if_dtensor(param) if id(param) not in param_mappings: param_mappings[id(param)] = i return param_mappings @@ -27,10 +38,25 @@ def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) - def get_param_id_to_sharded_param_map( model_sharded_state_dict: ShardedStateDict, optim_params_iter: Iterable[torch.nn.Parameter] -) -> Dict[int, ShardedTensor]: - model_sharded_state_dict, _ = extract_sharded_tensors(model_sharded_state_dict) +) -> Dict[int, Union[ShardedTensor, ShardedTensorFactory]]: + """Generate mapping from optimizer state ids to model sharded parameters. + + Args: + model_sharded_state_dict: sharded state dict with all model sharded tensors + (can have any structure) + optim_params_iter: iterable which iterates over model parameters tracked by the optimizer. + The iteration must be in the same order as in the optimizer parameters. + + Returns: + Dict[int, Union[ShardedTensor, ShardedTensorFactory]]: mapping from optimizer state ids + to model sharded parameters. + """ + model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(model_sharded_state_dict) id_to_sharded_param_map = {} param_to_id_map = get_optim_param_to_id_map(optim_params_iter) + # If using PyTorch FSDP2 the values in model_sharded_state_dict would + # have been converted to local tensors during initialization. + # See the make_(tp)_sharded_tensor_for_checkpoint functions. 
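    # Match each model ShardedTensor/ShardedTensorFactory to its optimizer param id
    # via object identity of the wrapped `data` tensor; this is why `optim_params_iter`
    # must follow the optimizer's own parameter order (editor's note).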
for ten in nested_values(model_sharded_state_dict): if id(ten.data) in param_to_id_map: id_to_sharded_param_map[param_to_id_map[id(ten.data)]] = ten @@ -47,14 +73,31 @@ def get_param_id_to_sharded_param_map( def make_sharded_optimizer_tensor( - model_param: ShardedTensor, optim_param: torch.Tensor, prefix: str -) -> ShardedTensor: - assert ( - tuple(optim_param.shape) == model_param.local_shape - ), f'Optimizer shape ({tuple(optim_param.shape)} does not match model shape ({model_param.local_shape})' - return replace( + model_param: Union[ShardedTensor, ShardedTensorFactory], optim_param: torch.Tensor, prefix: str +) -> Union[ShardedTensor, ShardedTensorFactory]: + """Build a ShardedTensor or ShardedTensorFactory for optimizer param based on model param + + Args: + model_param (Union[ShardedTensor, ShardedTensorFactory]): model param + optim_param (torch.Tensor): corresponding optimizer param + prefix (str): optimizer prefix for the ShardedTensor or ShardedTensorFactory + + Returns: + Union[ShardedTensor, ShardedTensorFactory]: wrapped optimizer parameter + """ + optim_param = to_local_if_dtensor(optim_param) + if isinstance(model_param, ShardedTensorFactory): + return replace(model_param, key=f'{prefix}.{model_param.key}', data=optim_param) + + assert tuple(optim_param.shape) == model_param.local_shape, ( + f'Optimizer shape ({tuple(optim_param.shape)} does not match model shape ' + f'({model_param.local_shape})' + ) + sh_ten = replace( model_param, key=f'{prefix}.{model_param.key}', data=optim_param, dtype=optim_param.dtype ) + sh_ten.validate_metadata_integrity() + return sh_ten def optim_state_to_sharding_state( @@ -62,6 +105,24 @@ def optim_state_to_sharding_state( id_to_sharded_param_map: Dict[int, ShardedTensor], exclude_keys: Tuple[str] = (), ): + """Turn optimizer state dict to sharded state dict based on model state dict *in-place*. + + Can be used to add sharding information to most common optimizer state dict. + Creates separate ShardedTensors for each key in `optim_state_dict['state']` + (e.g. for torch.optim.Adam there will be separate tensors for `exp_avg` and `exp_avg_sq`) + + Args: + optim_state_dict (StateDict): optimizer state dict with + state parameters under `state` key and group hyperparameters under + `param_groups` -> `params` key. + id_to_sharded_param_map (Dict[int, ShardedTensor]): mapping from optimizer param ids + to model sharded tensors. Can be generated with `get_param_id_to_sharded_param_map` + function. + exclude_keys (Tuple[str]): optimizer state keys to exclude from the final state dict. + + Returns: + None: state dict is modified in place + """ sharded_state = {} for param_id, param_state in optim_state_dict['state'].items(): sharded_state[param_id] = {} @@ -77,5 +138,5 @@ def optim_state_to_sharding_state( optim_state_dict['param_groups'] = deepcopy(optim_state_dict['param_groups']) for group in optim_state_dict['param_groups']: - group['params'] = LocalNonpersitentObject(group['params']) + group['params'] = LocalNonpersistentObject(group['params']) optim_state_dict['state'] = sharded_state diff --git a/megatron/core/dist_checkpointing/serialization.py b/megatron/core/dist_checkpointing/serialization.py index a70e38b474..d655b365d5 100644 --- a/megatron/core/dist_checkpointing/serialization.py +++ b/megatron/core/dist_checkpointing/serialization.py @@ -1,33 +1,34 @@ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. +""" Entrypoints for saving and loading the distributed checkpoints. 
+ +Functions `load` and `save` are equivalents of `torch.load` and `torch.save` +but expect torch.Tensors to be wrapped with classes from the `mapping module`. +Additionally, `load` expects the sharded state dict argument as a guidance for +loading the sharded tensors. +""" + import logging -import os -from collections import Counter, defaultdict -from itertools import chain from pathlib import Path -from typing import Iterable, List, Tuple, Union +from typing import Callable, Dict, Optional, Set, Tuple, Union -import numpy as np import torch -from .core import CheckpointingConfig, maybe_load_config, save_config -from .dict_utils import ( - dict_list_map_inplace, - diff, - extract_matching_values, - map_reduce, - merge, - nested_values, -) +from . import ShardedTensor +from .core import CheckpointingConfig, save_config +from .dict_utils import extract_matching_values, merge from .mapping import ( CheckpointingException, + CommonStateDict, ShardedObject, ShardedStateDict, - ShardedTensor, StateDict, - is_main_replica, + apply_factory_merges, ) +from .state_dict_utils import load_preprocess, save_preprocess +from .strategies.async_utils import AsyncRequest from .strategies.base import ( + AsyncSaveShardedStrategy, LoadCommonStrategy, LoadShardedStrategy, SaveCommonStrategy, @@ -35,92 +36,142 @@ StrategyAction, get_default_strategy, ) -from .utils import extract_sharded_tensors, extract_sharded_tensors_or_nonpersistent - -COMMON_STATE_FNAME = 'common.pt' +from .utils import extract_sharded_base +from .validation import ( + StrictHandling, + determine_global_metadata, + parse_strict_flag, + validate_integrity_and_strict_load, + validate_sharded_objects_handling, + verify_checkpoint_and_load_strategy, +) logger = logging.getLogger(__name__) +# flat state dict with sharded objects without any data +CkptShardedMetadata = Dict[str, Union[ShardedTensor, ShardedObject]] + + def load( sharded_state_dict: ShardedStateDict, checkpoint_dir: str, - sharded_strategy: Union[LoadShardedStrategy, None] = None, - common_strategy: Union[LoadCommonStrategy, None] = None, + sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None, + common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None, validate_access_integrity: bool = True, -) -> StateDict: + strict: Union[str, StrictHandling] = StrictHandling.ASSUME_OK_UNEXPECTED, +) -> Union[StateDict, Tuple[StateDict, Set[str], Set[str]]]: """Loading entrypoint. - Arguments: + In the steps below, the following verbs refer to corresponding objects: + - load = load from checkpoint + - extract = extract from sharded_state_dict + - add = add to the final state dict + Steps: + 1. Load common state dict and form the base of the result state dict + 2. Apply factories to sharded_state_dict + 3. Extract LocalNonPersistentObject and add + 4. (optional) Extract ShardedObjects, load and add + 5. Extract ShardedBase, load, apply factory merges and add + + Args: sharded_state_dict (ShardedStateDict): state dict of the existing model populated with ShardedTensors. Used as a mapping to determine which parts of global tensors stored in the checkpoint should be loaded. 
checkpoint_dir (str): directory with the checkpoint - sharded_strategy (LoadShardedStrategy, optional): configures loading behavior for sharded tensors - common_strategy (LoadCommonStrategy, optional): configures loading behavior for common data + sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): + configures loading behavior for sharded tensors + common_strategy (LoadCommonStrategy, Tuple[str, int], optional): + configures loading behavior for common data validate_access_integrity (bool default = True): checks if each tensor shard is accessed exactly once (as main replica) by some process + strict (StrictHandling, str, optional): determines the behavior in case of a mismatch + between the requested sharded state dict and the checkpoint. See `StrictHandling` docs + for more details. Some values affect the return value of this function + (missing and unexpected keys are returned). + Defaults to `True` (StrictHandling.ASSUME_OK_UNEXPECTED) which doesn't + incur any performance overhead. Other recommended values + are: `False` (StrictHandling.LOG_UNEXPECTED) which logs only unexpected keys + or `StrictHandling.RETURN_ALL` which returns all mismatch keys. + + Returns: + StateDict or Tuple[StateDict, Set[str], Set[str]]: in most cases only + the loaded state dict is returned. If `strict` flag was set to """ - if common_strategy is not None: - raise NotImplementedError('The only supported common strategy is torch') + sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy( + checkpoint_dir, sharded_strategy, common_strategy + ) checkpoint_dir = Path(checkpoint_dir) - common_state_dict = load_common_state_dict(checkpoint_dir) - if not sharded_state_dict: - return common_state_dict - - sharded_objects, sharded_state_dict = load_sharded_objects(sharded_state_dict, checkpoint_dir) - merge(common_state_dict, sharded_objects) + common_state_dict = common_strategy.load_common(checkpoint_dir) - saved_config = maybe_load_config(checkpoint_dir) - if saved_config is None: - raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint') - - sharded_state_dict, _ = extract_sharded_tensors_or_nonpersistent(sharded_state_dict) - sharded_state_dict, nonpersistent_state_dict = extract_sharded_tensors(sharded_state_dict) - dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict) + sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess( + sharded_state_dict + ) merge(common_state_dict, nonpersistent_state_dict) - if validate_access_integrity: - validate_sharding_integrity(nested_values(sharded_state_dict)) + # At this point we are only dealing with ShardedBase objects + sharded_state_dict, _ = extract_sharded_base(sharded_state_dict) - if sharded_strategy is None: - sharded_strategy = get_default_strategy( - StrategyAction.LOAD_SHARDED, - saved_config.sharded_backend, - saved_config.sharded_backend_version, + # Validation + ckpt_sharded_metadata = None + local_metadata, global_metadata = None, None + strict = parse_strict_flag(strict) + if StrictHandling.requires_explicit_ckpt_mismatch_check(strict): + ckpt_sharded_metadata = load_sharded_metadata( + str(checkpoint_dir), sharded_strategy, common_strategy ) - else: - # TODO: implement consistency checks here - pass + if validate_access_integrity or StrictHandling.requires_global_app_metadata(strict): + local_metadata, global_metadata = determine_global_metadata(sharded_state_dict) + + sharded_state_dict, missing_keys, unexpected_keys = validate_integrity_and_strict_load( + 
sharded_state_dict, + strict, + validate_access_integrity, + local_metadata, + global_metadata, + ckpt_sharded_metadata, + ) + + # ShardedBase loading + if not sharded_strategy.can_handle_sharded_objects: + validate_sharded_objects_handling(sharded_strategy, common_strategy) + sharded_objects_state_dict, sharded_state_dict = extract_matching_values( + sharded_state_dict, lambda v: isinstance(v, ShardedObject) + ) + sharded_objects = common_strategy.load_sharded_objects( + sharded_objects_state_dict, checkpoint_dir + ) + merge(common_state_dict, sharded_objects) + loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir) merge(common_state_dict, loaded_state_dict) - return common_state_dict + loaded_state_dict = apply_factory_merges(common_state_dict, sh_ten_factories) -# TODO: implement it as common torch strategy -def load_common_state_dict(checkpoint_dir: Path): - return torch.load(Path(checkpoint_dir) / COMMON_STATE_FNAME, map_location='cpu') + if StrictHandling.requires_returning_mismatch_keys(strict): + return common_state_dict, missing_keys, unexpected_keys + else: + return common_state_dict -def load_sharded_objects(sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): - sharded_objects, sharded_state_dict = extract_matching_values( - sharded_state_dict, lambda v: isinstance(v, ShardedObject) - ) +def load_common_state_dict(checkpoint_dir: Path) -> StateDict: + """Load common (non-sharded) objects state dict from the checkpoint. - def load_sharded_object(sh_obj: ShardedObject): - sh_obj.data = None - load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt') - loaded_obj = torch.load(load_path) - return loaded_obj + Args: + checkpoint_dir (Path): checkpoint directory - return dict_list_map_inplace(load_sharded_object, sharded_objects), sharded_state_dict + Returns: + StateDict: state dict with non-sharded objects from the checkpoint + """ + sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(str(checkpoint_dir)) + return common_strategy.load_common(checkpoint_dir) def load_tensors_metadata( checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, None] = None -) -> ShardedStateDict: +) -> CkptShardedMetadata: """Load tensors metadata from the checkpoint. Returns a dictionary similar to a sharded state dict, but note that @@ -132,40 +183,117 @@ def load_tensors_metadata( Concrete implementation depends on the loading strategy. If no strategy is given, a default for a given backend is used. - """ - saved_config = maybe_load_config(checkpoint_dir) - if saved_config is None: - raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint') - if sharded_strategy is None: - sharded_strategy = get_default_strategy( - StrategyAction.LOAD_SHARDED, - saved_config.sharded_backend, - saved_config.sharded_backend_version, - ) - else: - # TODO: implement consistency checks here - pass + Args: + checkpoint_dir (str): checkpoint directory to load from + sharded_strategy (LoadShardedStrategy, optional): sharded strategy to load metadata. + Defaults to None - in this case a default load strategy for a given checkpoint type + is used. 
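    Editor's hedged note on typical usage (this mirrors `load_plain_tensors`
    defined later in this file):

        metadata_sd = load_tensors_metadata(ckpt_dir)   # values carry data=None
        state_dict = load(metadata_sd, ckpt_dir, validate_access_integrity=False)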
+ + Returns: + CkptShardedMetadata: flat state dict without data describing ShardedTensors + in the checkpoint + """ + sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy( + checkpoint_dir, sharded_strategy + ) return sharded_strategy.load_tensors_metadata(Path(checkpoint_dir)) -def load_plain_tensors(checkpoint_dir: str): - """Load checkpoint tensors without any sharding. +def load_sharded_metadata( + checkpoint_dir: str, + sharded_strategy: Union[LoadShardedStrategy, None] = None, + common_strategy: Union[LoadCommonStrategy, None] = None, +) -> CkptShardedMetadata: + """Load sharded metadata from the checkpoint. + + Similar to `load_tensors_metadata`, but includes also ShardedObjects. - NOTE: common state dict is NOT included.""" + Returns a dictionary similar to a sharded state dict, but note that + the dictionary keys are simply ShardedTensor keys (contrary to the + actual sharded state dicts where keys correspond to state dict keys). + + Dict values are ShardedTensors without any sharding (so, the only useful + information is tensors global shape and dtype). + + Concrete implementation depends on the loading strategy. If no strategy is + given, a default for a given backend is used. + + Args: + checkpoint_dir (str): checkpoint directory to load from + sharded_strategy (LoadShardedStrategy, optional): sharded strategy to load metadata. + Defaults to None - in this case a default load strategy for a given checkpoint type + is used. + common_strategy (LoadCommonStrategy, optional): common strategy to load metadata. + Defaults to None - in this case a default load strategy for a given checkpoint type is + used. This strategy won't be used unless `sharded_strategy` can't handle ShardedObjects + + Returns: + CkptShardedMetadata: flat state dict without data describing ShardedTensors + and ShardedObjects in the checkpoint + """ + sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy( + checkpoint_dir, sharded_strategy, common_strategy + ) + sharded_metadata = sharded_strategy.load_sharded_metadata(Path(checkpoint_dir)) + if not sharded_strategy.can_handle_sharded_objects: + validate_sharded_objects_handling(sharded_strategy, common_strategy) + common_metadata = common_strategy.load_sharded_metadata(Path(checkpoint_dir)) + sharded_metadata = merge(sharded_metadata, common_metadata) + return sharded_metadata + + +def load_plain_tensors(checkpoint_dir: str) -> StateDict: + """Load checkpoint tensors without any sharding and plain structure. + + NOTE: common state dict is NOT included. + + Args: + checkpoint_dir (str): checkpoint directory to load the tensors from. + + Returns: + StateDict: checkpoint state dict containing only torch.Tensors. + """ sharded_state_dict = load_tensors_metadata(checkpoint_dir) # Don't validate integrity because shards will be overlapped # if world_size > 1 (all processes load whole tensors) return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False) +# +# def load_plain_tensors_and_objects(checkpoint_dir: str) -> StateDict: +# """Load checkpoint tensors and objects without any sharding and plain structure. +# +# NOTE: state dict structure might be different than the one used for checkpoint saving. +# NOTE: common state dict is NOT included. +# +# Args: +# checkpoint_dir (str): checkpoint directory to load the state dict from. +# +# Returns: +# StateDict: complete checkpoint state dict without any sharding. 
+# """ +# sharded_state_dict = load_tensors_metadata(checkpoint_dir) +# # Don't validate integrity because shards will be overlapped +# # if world_size > 1 (all processes load whole tensors) +# return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False) + + +def remove_sharded_tensors(checkpoint_dir: str, key_prefix: str): + """determine the appropriate sharding strategy and delegate removal to the sharded strategy""" + sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(checkpoint_dir) + sharded_strategy.remove_sharded_tensors(checkpoint_dir, key_prefix) + + def save( sharded_state_dict: ShardedStateDict, checkpoint_dir: str, - sharded_strategy: Union[SaveShardedStrategy, None] = None, - common_strategy: Union[SaveCommonStrategy, None] = None, + sharded_strategy: Union[SaveShardedStrategy, Tuple[str, int], None] = None, + common_strategy: Union[SaveCommonStrategy, Tuple[str, int], None] = None, validate_access_integrity: bool = True, -): + async_sharded_save: bool = False, + preprocess_common_before_consistancy_check: Callable[[CommonStateDict], StateDict] = None, +) -> Optional[AsyncRequest]: """Saving entrypoint. Extracts ShardedTensors from the given state dict. Rank 0 saves the @@ -173,15 +301,46 @@ def save( The ShardedTensors are saved according to a strategy specified by the config. - Arguments: + Steps: + 1. Apply factories + 2. Extract and discard LocalNonPersistentObject + 3. Extract all ShardedBase object + 4. Save all other objects to common.pt + 5. (optional) Extract and save ShardedObjects + 6. Save all ShardedBase objects + 7. Write metadata.json file with backend and version metadata. + + Step (6) can be performed asynchronously (see `async_sharded_save`), in this + case the actual save is embodied in the returned async request and can be + scheduled by the external caller. For async request, step (7) is added as + one of the finalization functions, so that metadata.json is written only + if the checkpoint is complete. + + Args: sharded_state_dict (ShardedStateDict): state dict of the populated with ShardedTensors. Used as a mapping to determine how local tensors should be saved as global tensors in the checkpoint. checkpoint_dir (str): directory to save the checkpoint to - sharded_strategy (SaveShardedStrategy, optional): configures sharded tensors saving behavior and backend - common_strategy (SaveCommonStrategy, optional): configures common data saving behavior and backend + sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional): + configures sharded tensors saving behavior and backend + common_strategy (SaveCommonStrategy, Tuple[str, int], optional): + configures common data saving behavior and backend validate_access_integrity (bool default = True): checks if each tensor shard is accessed - exactly once (as main replica) by some process + exactly once (as main replica) by some process. + It also makes sure the common state dict is consistant across all ranks + async_sharded_save (bool, optional): if True, for the sharded state dict part + an async save implementation will be called, with the AsyncRequest + being returned to the caller. Note that it is the caller responsibility to + actually schedule the async save. Defaults to False. + preprocess_common_before_consistancy_check (Callable[[CommonStateDict], StateDict], None): + A callable function that will preprocess the common state dict (i.e can be used to + remove keys that we expect to be different in the state dict). 
The function must not + modify the original state dict + + Returns: + AsyncRequest (optional): if `async_sharded_save` is True, returns + async request that should be scheduled by the caller of this function. + None otherwise. """ checkpoint_dir = Path(checkpoint_dir) @@ -200,164 +359,66 @@ def save( raise NotImplementedError('The only supported common strategy is torch') if sharded_strategy is None: - sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, 'zarr', 1) - - sharded_state_dict, state_dict = extract_sharded_tensors_or_nonpersistent(sharded_state_dict) - sharded_state_dict, _ = extract_sharded_tensors(sharded_state_dict) - sharded_tensors = list(nested_values(sharded_state_dict)) - if validate_access_integrity: - validate_sharding_integrity(sharded_tensors) - - _save_common_dict(state_dict, checkpoint_dir, True) - - sharded_strategy.save(sharded_tensors, checkpoint_dir) - save_config( - CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version), checkpoint_dir + sharded_strategy = get_default_save_sharded_strategy() + if not isinstance(sharded_strategy, SaveShardedStrategy): + assert isinstance(sharded_strategy, tuple), type(sharded_strategy) + sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, *sharded_strategy) + + if common_strategy is None: + common_strategy = get_default_save_common_strategy() + if not isinstance(common_strategy, SaveCommonStrategy): + assert isinstance(common_strategy, tuple), type(common_strategy) + common_strategy = get_default_strategy(StrategyAction.SAVE_COMMON, *common_strategy) + + sharded_state_dict, state_dict = save_preprocess( + sharded_state_dict, validate_access_integrity, preprocess_common_before_consistancy_check ) + common_strategy.save_common(state_dict, checkpoint_dir) -# TODO: implement it as common torch strategy -def _save_common_dict( - state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False -): - common_state_dict = _extract_and_save_sharded_objects( - state_dict, checkpoint_dir, validate_consistency - ) - if torch.distributed.get_rank() == 0: - torch.save(common_state_dict, checkpoint_dir / COMMON_STATE_FNAME) - if validate_consistency: - # TODO: implement checking consistency with rank 0 common dict on other ranks - pass - # torch.distributed.barrier() - # if not torch.distributed.get_rank() == 0: - # rank_0_state_dict = torch.load(checkpoint_dir / COMMON_STATE_FNAME) - # print(diff(common_state_dict, rank_0_state_dict)) - - -def _extract_and_save_sharded_objects( - state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False -): - sharded_objects, state_dict = extract_matching_values( - state_dict, lambda v: isinstance(v, ShardedObject) - ) - sharded_objects = list(nested_values(sharded_objects)) - if validate_consistency: - validate_objects_sharding_integrity(sharded_objects) - for sh_obj in sharded_objects: - if is_main_replica(sh_obj.replica_id): - save_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt') - os.makedirs(save_path.parent, exist_ok=True) - torch.save(sh_obj.data, save_path) - return state_dict - - -def validate_sharding_integrity(sharded_tensors: Iterable[ShardedTensor]): - sharding = [ten.without_data() for ten in sharded_tensors] - all_sharding = [None] * torch.distributed.get_world_size() - torch.distributed.all_gather_object(all_sharding, sharding) - if torch.distributed.get_rank() != 0: - return - - key_shardings = defaultdict(list) - for rank, rank_shardings in enumerate(all_sharding): - for sharding in 
rank_shardings: - key_shardings[sharding.key].append((rank, sharding)) - for key, shardings in key_shardings.items(): - _validate_sharding_for_key(shardings) - - -def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]): - global_shape = rank_sharding[0][1].global_shape - local_shape = rank_sharding[0][1].local_shape - dtype = rank_sharding[0][1].dtype - has_flattened_range = rank_sharding[0][1].flattened_range is not None - for rank, sharding in rank_sharding: - assert sharding.dtype == dtype, (sharding.dtype, dtype) - assert sharding.global_shape == global_shape, (sharding.global_shape, global_shape) - assert sharding.local_shape == local_shape, (sharding.local_shape, local_shape) - assert (sharding.flattened_range is not None) == has_flattened_range, ( - (sharding.flattened_range is not None), - has_flattened_range, + if not sharded_strategy.can_handle_sharded_objects: + validate_sharded_objects_handling(sharded_strategy, common_strategy) + sharded_objects_state_dict, sharded_state_dict = extract_matching_values( + sharded_state_dict, lambda v: isinstance(v, ShardedObject) ) + common_strategy.save_sharded_objects(sharded_objects_state_dict, checkpoint_dir) - shard_access_cnt = _compute_shards_access(rank_sharding) - if has_flattened_range: - map_reduce( - rank_sharding, - lambda x: x[1].global_offset, - lambda x: x[1], - _validate_sharding_for_key_flattened, - ) - else: - if not torch.all(shard_access_cnt == 1): - logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}') - raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}') - - -def _compute_shards_access(rank_sharding): - def chunk_offset(sharding): - assert len(sharding.global_offset) == len(sharding.local_shape) + sharding.prepend_axis_num - return tuple( - chain( - (off for off in sharding.global_offset[: sharding.prepend_axis_num]), - ( - off // sh - for off, sh in zip( - sharding.global_offset[sharding.prepend_axis_num :], sharding.local_shape - ) - ), + def metadata_finalize_fn(): + if torch.distributed.get_rank() == 0: + save_config( + CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version), + checkpoint_dir, ) - ) + torch.distributed.barrier() - shard_access_cnt = torch.zeros( - rank_sharding[0][1].axis_fragmentations, dtype=torch.int, device='cpu' - ) - for rank, sharding in rank_sharding: - if is_main_replica(sharding.replica_id): - shard_access_cnt[chunk_offset(sharding)] += 1 - # TODO: consider validating different replicas too - return shard_access_cnt - - -def _validate_sharding_for_key_flattened(tensors_by_shard): - all_slices = [] - local_shape = tensors_by_shard[0].local_shape - for sharding in tensors_by_shard: - assert sharding.local_shape == local_shape - sharding: ShardedTensor - if not is_main_replica(sharding.replica_id): - # TODO: this checks only saving (and loading replica_id=0) consistency - continue - - all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop)) - - starts, stops = map(np.asarray, zip(*sorted(all_slices))) - if ( - starts[0] != 0 - or stops[-1] != np.product(local_shape) - or not np.all(starts[1:] == stops[:-1]) - ): - logger.error( - f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. 
Ranges: {(starts, stops)}' - ) + if not async_sharded_save: + sharded_strategy.save(sharded_state_dict, checkpoint_dir) + metadata_finalize_fn() + return + + if not isinstance(sharded_strategy, AsyncSaveShardedStrategy): raise CheckpointingException( - f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}' + f'Cannot apply async_save to non-async strategy {sharded_strategy}' ) + async_request = sharded_strategy.async_save(sharded_state_dict, checkpoint_dir) + async_request.finalize_fns.append(metadata_finalize_fn) + return async_request -def validate_objects_sharding_integrity(sharded_objects: List[ShardedObject]): - """ Ensure uniqueness of saved objects. """ - local_sh_objs = [sh_obj.without_data() for sh_obj in sharded_objects] - all_sh_objs = [None] * torch.distributed.get_world_size() - torch.distributed.all_gather_object(all_sh_objs, local_sh_objs) - if torch.distributed.get_rank() != 0: - return - unique_keys = [ - sh_obj.unique_key - for sh_obj in chain.from_iterable(all_sh_objs) - if is_main_replica(sh_obj.replica_id) - ] - if len(unique_keys) != len(set(unique_keys)): - duplicates = {k: cnt for k, cnt in Counter(unique_keys).items() if cnt > 1} - logger.error(f'Duplicate ShardedObject keys and counts: {duplicates}') - raise CheckpointingException(f'Duplicate ShardedObject keys: {list(duplicates.keys())}') +def get_default_save_sharded_strategy( + backend: str = 'torch_dist', version: int = 1 +) -> SaveShardedStrategy: + """Get default save sharded strategy.""" + return get_default_strategy(StrategyAction.SAVE_SHARDED, backend, version) + + +def get_default_save_common_strategy( + backend: str = 'torch', version: int = 1 +) -> SaveCommonStrategy: + """Get default save common strategy.""" + return get_default_strategy(StrategyAction.SAVE_COMMON, backend, version) + + +def get_default_load_sharded_strategy(checkpoint_dir: str) -> LoadShardedStrategy: + """Get default load sharded strategy.""" + return verify_checkpoint_and_load_strategy(checkpoint_dir)[0] diff --git a/megatron/core/dist_checkpointing/state_dict_utils.py b/megatron/core/dist_checkpointing/state_dict_utils.py new file mode 100644 index 0000000000..ed2e434f51 --- /dev/null +++ b/megatron/core/dist_checkpointing/state_dict_utils.py @@ -0,0 +1,85 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +""" Utilities for transforming state_dict.""" + +from typing import Callable + +from .dict_utils import dict_list_map_inplace, extract_matching_values +from .mapping import ( + CommonStateDict, + ShardedStateDict, + ShardedTensorFactory, + StateDict, + apply_factories, +) +from .utils import extract_nonpersistent, extract_sharded_base +from .validation import determine_global_metadata, validate_sharding_integrity + + +def save_preprocess( + sharded_state_dict: ShardedStateDict, + validate_access_integrity: bool = True, + preprocess_common_before_consistancy_check: Callable[[CommonStateDict], StateDict] = None, +): + """Preprocesses the given state dictionary by applying factories, + discarding non-persistent data and extracting the common state dictionary. + Optionally, it can validate sharding integrity. + + Args: + sharded_state_dict (ShardedStateDict): The initial state dictionary to be preprocessed. + validate_access_integrity (bool): If True, triggers validation of sharding integrity. 
+ preprocess_common_before_consistancy_check (callable, None): A callable function + that will preprocess the common state dict (i.e can be used to remove keys + that we expect to be different in the state dict) + + Returns: + Tuple[ShardedStateDict, dict]: + The preprocessed sharded state dictionary and the common state dictionary. + """ + apply_factories(sharded_state_dict) + _, sharded_state_dict = extract_nonpersistent(sharded_state_dict) + sharded_part, common_state_dict = extract_sharded_base(sharded_state_dict) + if validate_access_integrity: + preprocessed_common_state_dict = common_state_dict + if preprocess_common_before_consistancy_check: + preprocessed_common_state_dict = preprocess_common_before_consistancy_check( + common_state_dict + ) + validate_sharding_integrity( + determine_global_metadata(sharded_part)[1], + common_state_dict=preprocessed_common_state_dict, + ) + return sharded_part, common_state_dict + + +def load_preprocess(sharded_state_dict: ShardedStateDict): + """Preprocesses the given state dictionary by applying factories + and extracting non-persistent data, without modifying the original dictionary. + + Args: + sharded_state_dict (ShardedStateDict): + The initial state dictionary to be processed (remains unchanged). + + Returns: + Tuple[ShardedStateDict, dict, dict]: + - A preprocessed copy of the sharded state dictionary. + - A dictionary containing non-persistent state data. + - A dictionary of `ShardedTensorFactory` instances. + """ + # Create a copy of sharded_state_dict as the passed in state dict may have + # references that prevent tensors from being deallocated + sharded_state_dict, _ = extract_matching_values(sharded_state_dict, lambda x: True) + + sh_ten_factories, _ = extract_matching_values( + sharded_state_dict, + lambda x: isinstance(x, ShardedTensorFactory), + return_lists_as_dicts=True, + ) + apply_factories(sharded_state_dict) + + # Data inside sh_ten_factories no longer needed so delete them to reduce memory usage + dict_list_map_inplace(ShardedTensorFactory.without_data, sh_ten_factories) + # Non-persistent objects + nonpersistent_state_dict, sharded_state_dict = extract_nonpersistent(sharded_state_dict) + dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict) + return sharded_state_dict, nonpersistent_state_dict, sh_ten_factories diff --git a/megatron/core/dist_checkpointing/strategies/__init__.py b/megatron/core/dist_checkpointing/strategies/__init__.py index 7177d973cf..a786b8e84a 100644 --- a/megatron/core/dist_checkpointing/strategies/__init__.py +++ b/megatron/core/dist_checkpointing/strategies/__init__.py @@ -1,16 +1,7 @@ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. 
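Taken together, the entrypoints and preprocessing helpers above are meant to be used roughly as follows (editor's sketch only; it assumes an initialized torch.distributed process group, an existing empty checkpoint directory, and hypothetical keys):

import torch
from megatron.core import dist_checkpointing
from megatron.core.dist_checkpointing import ShardedTensor

rank = torch.distributed.get_rank()
world = torch.distributed.get_world_size()
local_weight = torch.randn(16, 32)

sharded_sd = {
    # each rank owns chunk `rank` out of `world` along dim 0
    'model.weight': ShardedTensor.from_rank_offsets(
        'model.weight', local_weight, (0, rank, world)
    ),
    'iteration': 100,  # plain values end up in the common (non-sharded) part
}
dist_checkpointing.save(sharded_sd, '/tmp/ckpt')  # async_sharded_save=True would return an AsyncRequest

# Loading mirrors saving: the sharded state dict acts as a request for which shards to read.
template = {
    'model.weight': ShardedTensor.from_rank_offsets(
        'model.weight', torch.empty_like(local_weight), (0, rank, world)
    ),
}
loaded = dist_checkpointing.load(template, '/tmp/ckpt')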
""" Various loading and saving strategies """ +from megatron.core.dist_checkpointing.strategies.common import register_default_common_strategies -import logging - -logger = logging.getLogger(__name__) - -try: - import tensorstore - import zarr - - from .tensorstore import _import_trigger - from .zarr import _import_trigger -except ImportError: - logger.warning('Zarr-based strategies will not be registered because of missing packages') +# We load "common" strategies by default to be always available +register_default_common_strategies() diff --git a/megatron/core/dist_checkpointing/strategies/async_utils.py b/megatron/core/dist_checkpointing/strategies/async_utils.py new file mode 100644 index 0000000000..3285ac6cf7 --- /dev/null +++ b/megatron/core/dist_checkpointing/strategies/async_utils.py @@ -0,0 +1,543 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +This module provides an async utilities which allow to start +a checkpoint save process in the background. +""" +import gc +import logging +from abc import ABC, abstractmethod +from collections import deque +from contextlib import contextmanager +from queue import Empty +from time import sleep, time +from typing import Callable, Dict, List, NamedTuple, Optional, Tuple + +import torch +from torch import multiprocessing as mp + +from ..utils import debug_time + +logger = logging.getLogger(__name__) + + +@contextmanager +def _disable_gc(): + """Temporarily disables GC.""" + gc_enabled = gc.isenabled() + try: + if gc_enabled: + gc.disable() + yield + finally: + if gc_enabled: + gc.enable() + + +class AsyncRequest(NamedTuple): + """Represents an async request that needs to be scheduled for execution. + + Args: + async_fn (Callable, optional): async function to call. None represents noop. + async_fn_args (Tuple): args to pass to `async_fn`. + finalize_fns (List[Callable]): list of functions to call to finalize the request. + These functions will be called synchronously after `async_fn` is done + *on all ranks*. + async_fn_kwargs (Tuple): kwargs to pass to `async_fn`. + preload_fn (Callable): preload function to stage tensors from GPU to Host. + This should be self-contained with a proper list of arguments with `partial`. + is_frozen (Bool): a flag to indicate this async request can be modified or not. + call_idx (int): index variable used to order async requests for synchronization + in preloading and writing tensors on the async caller + + """ + + async_fn: Optional[Callable] + async_fn_args: Tuple + finalize_fns: List[Callable] + async_fn_kwargs: Dict = {} + preload_fn: Callable = None + is_frozen: bool = False + call_idx: int = 0 + + def add_finalize_fn(self, fn: Callable) -> None: + """Adds a new finalize function to the request. + + Args: + fn (Callable): function to add to the async request. This function + will be called *after* existing finalization functions. + + Returns: + None + """ + if self.is_frozen: + raise RuntimeError('Cannot add finalization functions to a frozen AsyncRequest') + self.finalize_fns.append(fn) + + def execute_sync(self) -> None: + """Helper to synchronously execute the request. + + This logic is equivalent to what should happen in case of the async call. + """ + if self.async_fn is not None: + self.async_fn(*self.async_fn_args) + torch.distributed.barrier() + for finalize_fn in self.finalize_fns: + finalize_fn() + + def freeze(self) -> 'AsyncRequest': + """Freezes the async request, disallowing adding new finalization functions. 
+ + Returns: + AsyncRequest: new async request with all same fields except for the + `is_frozen` flag. + """ + return self._replace(is_frozen=True) + + +class AsyncCaller(ABC): + """Wrapper around mp.Process that ensures correct semantic of distributed finalization. + + Starts process asynchronously and allows checking if all processes on all ranks are done. + """ + + @abstractmethod + def schedule_async_call(self, async_req: AsyncRequest) -> None: + """Schedule `async_req` with some process forking or reusing + persistent worker + + This method must be called on all ranks. + + Args: + async_req (AsyncRequest): `AsyncRequest` object containing to + start async process + """ + raise NotImplementedError("This should be implemented") + + @abstractmethod + def is_current_async_call_done(self, blocking: bool, no_dist: bool) -> bool: + """Check if async save is finished on all ranks. + + For semantic correctness, requires rank synchronization in each check. + This method must be called on all ranks. + + Args: + blocking (bool, optional): if True, will wait until the call is done + on all ranks. Otherwise, returns immediately if at least one rank + is still active. Defaults to False. + no_dist (bool, Optional): if True, training ranks simply check its + asynchronous checkpoint writer without synchronization. + + Returns: + bool: True if all ranks are done (immediately of after active wait + if `blocking` is True), False if at least one rank is still active. + + """ + raise NotImplementedError("This should be implemented") + + def sync_all_async_calls(self, is_alive: int) -> bool: + """Check if all ranks have completed async checkpoint writing + + Args: + is_alive (bool): if True, the current async request is not completed + + Returns: + bool: True if all ranks are done, False if at least one rank is still active. + + """ + ten = torch.tensor([is_alive], dtype=torch.int, device=torch.cuda.current_device()) + torch.distributed.all_reduce(ten) + return ten[0] == 0 + + @abstractmethod + def close(self): + """Terminate the async caller at exit of an application or some termination conditions""" + logger.info(f"AsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller") + + def __del__(self): + self.close() + + +class TemporalAsyncCaller(AsyncCaller): + """Wrapper around mp.Process that ensures correct semantic of distributed finalization. + + Starts process asynchronously and allows checking if all processes on all ranks are done. + """ + + def __init__(self): + self.process: Optional[mp.Process] = None + self.start_time: Optional[float] = None + + @_disable_gc() + def schedule_async_call(self, async_req: AsyncRequest) -> None: + """Spawn a process with `async_fn` as the target. + + This method must be called on all ranks. + + Args: + async_fn (Callable, optional): async function to call. If None, + no process will be started. 
+ async_req (AsyncRequest): `AsyncRequest` object containing to + start async process + """ + if async_req.async_fn is None: + return # nothing to do + + async_fn_args = list(async_req.async_fn_args) + if async_req.preload_fn: + # If there's a preload_fn in `async_req`, we call this func + # to do the defined action in `async_req.preload_fn` to + # stage GPU tensors to its defined destination + async_fn_args[1] = async_req.preload_fn() + + rank = torch.distributed.get_rank() + start_sync = time() + torch.cuda.synchronize() + end_sync = time() + logger.debug(f"rank: {rank}, takes {end_sync - start_sync} to finish D2H ") + + ctx = mp.get_context('fork') + self.start_time = time() + self.process = ctx.Process( + target=async_req.async_fn, args=async_fn_args, kwargs=async_req.async_fn_kwargs + ) + self.process.start() + init_time = time() + logger.debug(f"rank: {rank}, takes {init_time - self.start_time} to schedule async ckpt ") + + def is_current_async_call_done(self, blocking: bool = False, no_dist: bool = False) -> bool: + """Check if async save is finished on all ranks. + + For semantic correctness, requires rank synchronization in each check. + This method must be called on all ranks. + + Args: + blocking (bool, optional): if True, will wait until the call is done + on all ranks. Otherwise, returns immediately if at least one rank + is still active. Defaults to False. + no_dist (bool, Optional): if True, training ranks simply check its + asynchronous checkpoint writer without synchronization. + + Returns: + bool: True if all ranks are done (immediately of after active wait + if `blocking` is True), False if at least one rank is still active. + """ + # The following takes the same overhead + # as torch.distributed.barrier (single integer all-reduce) + is_alive = int(self.process.is_alive()) if self.process is not None else 0 + is_done = not is_alive if no_dist else self.sync_all_async_calls(is_alive) + + if not is_done and blocking: + self.close() + is_done = True + return is_done + + def close(self): + if self.process: + logger.debug(f"rank: {torch.distributed.get_rank()}, joining self.process") + self.process.join() + self.process = None + logger.debug( + "TemporalAsyncCaller: Async process join finished " + f"after {time() - self.start_time:.2f}s from forking" + ) + self.start_time = None + + +class PersistentAsyncCaller(AsyncCaller): + """Wrapper around mp.Process that ensures correct semantic of distributed finalization. + + Starts process asynchronously and allows checking if all processes on all ranks are done. + """ + + def __init__(self): + self.process: mp.Process = None + self.start_time: Optional[float] = None + ctx = mp.get_context('spawn') + # main queue to deliver `AsyncRequest` from host to the ckpt worker + self.queue: mp.JoinableQueue = ctx.JoinableQueue() + # Queue used to synchronize for the completion of preloading tensors to host + # between a trainer and ckpt worker + self.preload_q: mp.JoinableQueue = ctx.JoinableQueue() + # Queue used to inform trainer when the saving is completed + self.comp_q: mp.Queue = ctx.Queue() + self.cur_item: int = None + self.cur_idx: int = -1 + + def schedule_async_call(self, async_req: AsyncRequest) -> None: + """Put `AsyncRequest` to the Persistent Async Caller + + This method must be called on all ranks. + + Args: + async_fn (Callable, optional): async function to call. If None, + no process will be started. 
+ async_req (AsyncRequest): `AsyncRequest` object containing to + schedule a checkpointing request + """ + if async_req.async_fn is None: + return # nothing to do + + start_sync = end_sync = None + + self.start_time = time() + if self.process is None: + ctx = mp.get_context('spawn') + logger.info( + f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Starting Async Caller" + ) + self.process: mp.Process = ctx.Process( + target=PersistentAsyncCaller.async_loop, + args=( + torch.distributed.get_rank(), + self.queue, + self.preload_q, + self.comp_q, + logger.getEffectiveLevel(), + ), + ) + self.process.start() + logger.info( + f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Started Async Caller" + ) + + if async_req.preload_fn: + self.preload_q.put(async_req.call_idx) + self.queue.put(async_req) + logger.debug(f"rank: {torch.distributed.get_rank()}, put {async_req.call_idx}") + + if async_req.preload_fn: + start_sync = time() + # Synchronize for pre-staging tensors + self.preload_q.join() + end_sync = time() + logger.debug( + f"rank: {torch.distributed.get_rank()}, " + f"takes {end_sync - start_sync} to finish D2H " + ) + + init_time = time() + logger.debug( + f"rank: {torch.distributed.get_rank()}, takes {init_time - self.start_time} " + "to schedule async ckpt " + ) + + def is_current_async_call_done(self, blocking: bool = False, no_dist: bool = False) -> bool: + """Check if async save is finished on all ranks. + + For semantic correctness, requires rank synchronization in each check. + This method must be called on all ranks. + + Args: + blocking (bool, optional): if True, will wait until the call is done + on all ranks. Otherwise, returns immediately if at least one rank + is still active. Defaults to False. + no_dist (bool, Optional): if True, training ranks simply check its + asynchronous checkpoint writer without synchronization. + + Returns: + bool: True if all ranks are done (immediately of after active wait + if `blocking` is True), False if at least one rank is still active. + """ + + is_alive: bool = False + + if self.process: + while self.cur_item is None: + try: + # Retrieve comp call_idx without waiting + self.cur_item = self.comp_q.get_nowait() + except Empty: + # This method is called after any `AsyncRequest` is pushed to the main loop + # So, the background writing is still active + # before the worker put call_idx to `comp_q` + if not blocking: + is_alive = True + break + sleep(0.1) + + if self.cur_item is not None: + logger.debug( + f"rank: {torch.distributed.get_rank()}, item: {self.cur_item}" + f" is completed, {is_alive}" + ) + + is_done = not is_alive if no_dist else self.sync_all_async_calls(is_alive) + # This is set to False when blocking == False so this routine is called again + # to simply call `sync_all_async_calls` to check if other ranks complete the writing + if is_done: + # The current request is completed globally. Reset the current item for polling. 
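                # (Editor's note: clearing cur_item means the next call to this
                # method starts polling comp_q for the following request.)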
+ logger.debug( + f"rank: {torch.distributed.get_rank()}, item: {self.cur_item}" + f" is completed globally, {is_done}" + ) + self.cur_item = None + + return is_done + + def close(self): + logger.info( + f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller" + ) + if self.process: + self.queue.put('DONE') + self.queue.join() + self.process.join() + self.process = None + + @staticmethod + @_disable_gc() + def async_loop( + rank: int, + queue: mp.JoinableQueue, + preload_q: mp.JoinableQueue, + comp_q: mp.Queue, + log_level: int = logging.INFO, + ): + """Main function for the persistent checkpoint worker + + The persisent worker is created once and terminated at exit or + when application calls `close()` explictily + + This routine receives `AsyncRequest` and does `preload_fn` first and + put the integer value in `preload_q` to inform the trainer to proceed. + When the `async_fn` from the request` is completed (background saving is done), + it puts a integer value to `comp_q` to notify the trainer the completion. + + Args: + rank (int): the rank of the trainer where the persistent worker is created. + queue (mp.JoinableQueue): the main queue used to receive `AsyncRequest + from the training rank + preload_q (mp.JoinableQueue): a queue to inform trainer that preloading of tensors + from GPU to Host or dedicated location is completed + comp_q (mp.Queue): a queue to inform the training rank the completion of scheduled + async checkpoint request + log_level (int, Optional): an integer to set log-level in this spawned process + to get aligned with the training rank's logging level + + """ + logger = logging.getLogger(__name__) + logger.setLevel(log_level) + logger.info(f"PersistentAsyncCaller: persistent ckpt worker for {rank} has started") + while True: + item = queue.get() + if isinstance(item, str) and item == 'DONE': + queue.task_done() + break + elif isinstance(item, AsyncRequest): + async_fn_args = list(item.async_fn_args) + if item.preload_fn: + call_idx = preload_q.get() + # the 2nd arg is state dict + async_fn_args[1] = item.preload_fn() + logger.debug(f"{rank} has completed D2H of {call_idx}") + preload_q.task_done() + item.async_fn(*async_fn_args, **item.async_fn_kwargs) + logger.debug(f"{rank} has completed saving {item.call_idx}") + comp_q.put(item.call_idx) + queue.task_done() + + logger.info(f"PersistentAsyncCaller: persistent ckpt worker for {rank} has terminated") + + +class _ActiveAsyncRequest(NamedTuple): + """Helper to represent an active async call. + + Args: + idx (int): index of the call (starting from 0) + async_caller (DistributedAsyncCaller): async caller instance that represents + the async process handling the async request + async_request (AsyncRequest): async request that is being called + """ + + idx: int + async_caller: AsyncCaller + async_request: AsyncRequest + + +class AsyncCallsQueue: + """Manages a queue of async calls. + + Allows adding a new async call with `schedule_async_request` and finalizing + active calls with `maybe_finalize_async_calls`. 
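+
+    Illustrative usage sketch (`build_save_request()` is a hypothetical helper returning an
+    `AsyncRequest`, e.g. one produced by an `AsyncSaveShardedStrategy.async_save` call; all
+    calls below must be made on all ranks):
+
+        calls_queue = AsyncCallsQueue(persistent=True)
+        calls_queue.schedule_async_request(build_save_request())
+        # ... training continues; periodically reap finished saves ...
+        finished_idx = calls_queue.maybe_finalize_async_calls(blocking=False)
+        # at teardown: wait for outstanding saves and stop the persistent worker
+        calls_queue.close()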
+ """ + + def __init__(self, persistent: bool = False): + self.async_calls: deque[_ActiveAsyncRequest] = deque([]) + self.call_idx: int = -1 + self.persistent: bool = persistent + self.persistent_caller: AsyncCaller = None + + def _get_async_caller(self): + if not self.persistent: + return TemporalAsyncCaller() + if self.persistent_caller is None: + self.persistent_caller = PersistentAsyncCaller() + return self.persistent_caller + + def schedule_async_request(self, async_request: AsyncRequest) -> int: + """Start a new async call and add it to a queue of active async calls. + + This method must be called on all ranks. + + Args: + async_request (AsyncRequest): async request to start. + + Returns: + int: index of the async call that was started. + This can help the user keep track of the async calls. + """ + self.call_idx += 1 + async_caller = self._get_async_caller() + # Backward compatibility for local checkpointing built with the old AsyncRequest + if len(async_request._fields) != len(AsyncRequest._fields): + async_request = AsyncRequest(**async_request._asdict()) + + async_request = async_request._replace(call_idx=self.call_idx) + finalize_fns = async_request.finalize_fns + async_request = async_request._replace(finalize_fns=None) + async_request = async_request.freeze() + async_caller.schedule_async_call(async_request) + self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, finalize_fns)) + return self.call_idx + + def maybe_finalize_async_calls(self, blocking=False, no_dist=False) -> List[int]: + """Finalizes all available calls. + + This method must be called on all ranks. + + Args: + blocking (bool, optional): if True, will wait until all active requests + are done. Otherwise, finalizes only the async request that already + finished. Defaults to False. + Returns: + List[int]: list of indices (as returned by `schedule_async_request`) + of async calls that have been successfully finalized. + """ + call_idx_finalized = [] + while self.async_calls: + next_async_done = self.async_calls[0].async_caller.is_current_async_call_done( + blocking, no_dist + ) + if not next_async_done: + break + with debug_time("finalize", logger): + call_idx, _, finalize_fns = self.async_calls.popleft() + ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device()) + torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX) + assert ten.item() == call_idx, 'Unmatched async calls. ' + 'That probably means not all ranks are participating in async finalization' + for finalize_fn in finalize_fns: + finalize_fn() + call_idx_finalized.append(call_idx) + return call_idx_finalized + + def get_num_unfinalized_calls(self): + """Get the number of active async calls.""" + return len(self.async_calls) + + def close(self): + """Finalize all calls upon closing.""" + self.maybe_finalize_async_calls(blocking=True) + if self.persistent and self.persistent_caller: + self.persistent_caller.close() diff --git a/megatron/core/dist_checkpointing/strategies/base.py b/megatron/core/dist_checkpointing/strategies/base.py index 3989ea74a2..77eb325b92 100644 --- a/megatron/core/dist_checkpointing/strategies/base.py +++ b/megatron/core/dist_checkpointing/strategies/base.py @@ -1,90 +1,228 @@ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. +""" Strategies base interfaces. 
""" + from abc import ABC, abstractmethod from collections import defaultdict from enum import Enum from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, DefaultDict, Union -from ..mapping import CheckpointingException, ShardedStateDict, ShardedTensor, StateDict +from ..mapping import CheckpointingException, ShardedStateDict, StateDict +from .async_utils import AsyncCallsQueue, AsyncRequest class StrategyAction(Enum): + """Specifies save vs load and sharded vs common action.""" + LOAD_COMMON = 'load_common' LOAD_SHARDED = 'load_sharded' SAVE_COMMON = 'save_common' SAVE_SHARDED = 'save_sharded' -default_strategies = defaultdict(dict) +default_strategies: DefaultDict[str, dict[tuple, Any]] = defaultdict(dict) + +async_calls = AsyncCallsQueue() def get_default_strategy(action: StrategyAction, backend: str, version: int): + """Retrieves a default strategy for a given action, backend and version.""" + error_hint: str = None + try: + if backend == 'zarr': + error_hint = ' Please install `zarr` and `tensorstore!=0.1.46` packages' + from .tensorstore import register_default_tensorstore_strategies + + register_default_tensorstore_strategies() + from .zarr import register_default_zarr_strategies + + register_default_zarr_strategies() + elif backend == 'torch_dist': + error_hint = ' Please use PyTorch version >=2.1' + from .torch import register_default_torch_strategies + + register_default_torch_strategies() + except ImportError as e: + raise CheckpointingException( + f'Cannot import a default strategy for: {(action.value, backend, version)}. ' + f'Error: {e}. Hint: {error_hint}' + ) from e try: return default_strategies[action.value][(backend, version)] except KeyError as e: - hint = '' - if backend == 'zarr': - try: - import tensorstore - import zarr - except ImportError: - hint = ' Please install `zarr` and `tensorstore<=0.1.45` packages' raise CheckpointingException( - f'Cannot find a default strategy for: {(action.value, backend, version)}.{hint}' + f'Cannot find a default strategy for: {(action.value, backend, version)}' ) from e +def register_default_strategy( + action: StrategyAction, + backend: str, + version: int, + strategy: Union['SaveStrategyBase', 'LoadStrategyBase'], +): + """Adds a given strategy to the registry of default strategies. + + Args: + action (StrategyAction): specifies save/load and sharded/common + backend (str): backend that the strategy becomes a default for + version (int): version that the strategy becomes a default for + strategy (SaveStrategyBase, LoadStrategyBase): strategy to register + """ + default_strategies[action.value][(backend, version)] = strategy + + class LoadStrategyBase(ABC): + """Base class for a load strategy. Requires implementing checks for compatibility with a + given checkpoint version.""" + @abstractmethod - def check_backend_compatibility(self, loaded_version): + def check_backend_compatibility(self, loaded_backend): + """Verifies if this strategy is compatible with `loaded_backend`.""" raise NotImplementedError @abstractmethod def check_version_compatibility(self, loaded_version): + """Verifies if this strategy is compatible with `loaded_version`.""" raise NotImplementedError + @property + def can_handle_sharded_objects(self): + """Returns whether or not this strategy can handle loading ShardedObjects.""" + return False + class SaveStrategyBase(ABC): + """Base class for a save strategy. 
Requires defining a backend type and + version of the saved format.""" + def __init__(self, backend: str, version: int): self.backend = backend self.version = version + @property + def can_handle_sharded_objects(self): + """Returns whether or not this strategy can handle saving ShardedObjects.""" + return False + + def __str__(self): + return f'{self.__class__.__name__}({self.backend}, {self.version})' + class LoadCommonStrategy(LoadStrategyBase): + """Load strategy for common (non-sharded) objects""" + + @abstractmethod + def load_common(self, checkpoint_dir: Path): + """Load common part of the checkpoint.""" + raise NotImplementedError + @abstractmethod - def load(self, checkpoint_dir: Path): + def load_sharded_objects( + self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path + ): + """Load sharded objects from the checkpoint.""" + raise NotImplementedError + + def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict: + """Load just the metadata from the checkpoint.""" + if not self.can_handle_sharded_objects: + return {} raise NotImplementedError class LoadShardedStrategy(LoadStrategyBase): + """Load strategy for sharded tensors""" + @abstractmethod def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): + """Load the sharded part of the checkpoint.""" raise NotImplementedError @abstractmethod def load_tensors_metadata(self, checkpoint_dir: Path): - """Load tensors metadata from the checkpoint. + """Load tensors metadata from the checkpoint for ShardedTensors. Returns a dictionary similar to a sharded state dict, but note that the dictionary keys are simply ShardedTensor keys (contrary to the actual sharded state dicts where keys correspond to state dict keys). - Dict values are ShardedTensors without any sharding (so, the only useful - information is tensors global shape and dtype). + Dict values are ShardedTensors without any data and sharding (so, the + only useful information is tensors global shape and dtype). + """ + raise NotImplementedError( + f'Loading only tensors metadata not implemented for {self.__class__.__name__}' + ) + + def load_sharded_metadata(self, checkpoint_dir: Path): + """Load sharded metadata from the checkpoint for ShardedTensors and ShardedObjects. + + Returns a dictionary similar to a sharded state dict, but note that + the dictionary keys are simply sharded keys (contrary to the + actual sharded state dicts where keys correspond to state dict keys). + + Dict values are ShardedTensors or ShardedObjects without any data and sharding. 
""" + if not self.can_handle_sharded_objects: + return self.load_tensors_metadata(checkpoint_dir) raise NotImplementedError( - f'{self.__class__.__name__} doesnt allow loading only sharded metadata' + f'Loading only sharded metadata not implemented for {self.__class__.__name__}' ) + def remove_sharded_tensors(self, checkpoint_dir: str, key_prefix: str): + """Remove all tensors whose key starts with key_prefix""" + raise NotImplementedError + class SaveCommonStrategy(SaveStrategyBase): + """Save strategy for common (non-sharded) objects""" + @abstractmethod - def save(self, common_state_dict: StateDict, checkpoint_dir: Path): + def save_common(self, common_state_dict: StateDict, checkpoint_dir: Path): + """Save common part of the state dict.""" + raise NotImplementedError + + def save_sharded_objects( + self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path + ): + """Save sharded objects from the state dict.""" raise NotImplementedError class SaveShardedStrategy(SaveStrategyBase): + """Save strategy for sharded tensors""" + @abstractmethod - def save(self, sharded_tensors: List[ShardedTensor], checkpoint_dir: Path): + def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): + """Save the sharded part of the state dict.""" raise NotImplementedError + + +class AsyncSaveShardedStrategy(SaveShardedStrategy): + """Save strategy suitable for async save.""" + + @abstractmethod + def async_save( + self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path + ) -> AsyncRequest: + """Perform preparation and return an AsyncRequest to the external caller. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to save + checkpoint_dir (Path): checkpoint target directory + + Returns: + AsyncRequest: represents the async save function and finalization function. + It is the caller responsibility to actually schedule the async save. + """ + raise NotImplementedError + + def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): + """Each async strategy can be trivially used as a sync strategy.""" + async_request = self.async_save(sharded_state_dict, checkpoint_dir) + # multiprocessing routines may cause issue when called on parent process + # We keep this verbose call for now + global async_calls + async_calls.schedule_async_request(async_request) + async_calls.maybe_finalize_async_calls(blocking=True) diff --git a/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py b/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py new file mode 100644 index 0000000000..f8cb0326e2 --- /dev/null +++ b/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py @@ -0,0 +1,38 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +""" FS Reader with metadata cached support. """ + +import os +from typing import Union + +from torch.distributed.checkpoint import FileSystemReader, Metadata + + +class CachedMetadataFileSystemReader(FileSystemReader): + """ + Extends FileSystemReader to cache metadata for improved performance. + + Attributes: + _cached_metadata (Metadata or None): Cached metadata from the file system. + """ + + def __init__(self, path: Union[str, os.PathLike]) -> None: + """ + Initialize with file system path. + + Args: + path (Union[str, os.PathLike]): Path to the checkpoint directory or file. 
+ """ + super().__init__(path=path) + self._cached_metadata = None + + def read_metadata(self) -> Metadata: + """ + Read metadata from file system, caching for subsequent calls. + + Returns: + Metadata: Checkpoint metadata. + """ + if self._cached_metadata is None: + self._cached_metadata = super().read_metadata() + return self._cached_metadata diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py new file mode 100644 index 0000000000..4582f552fb --- /dev/null +++ b/megatron/core/dist_checkpointing/strategies/common.py @@ -0,0 +1,157 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" Common strategies. """ + +import logging +import os +from pathlib import Path + +import torch + +from megatron.core.dist_checkpointing.mapping import ShardedStateDict, StateDict +from megatron.core.dist_checkpointing.strategies.base import ( + SaveCommonStrategy, + StrategyAction, + register_default_strategy, +) + +from ..dict_utils import dict_list_map_inplace, nested_values +from ..mapping import CheckpointingException, ShardedObject, is_main_replica +from ..strategies.base import LoadCommonStrategy + +COMMON_STATE_FNAME = 'common.pt' + +logger = logging.getLogger(__name__) + + +def register_default_common_strategies(): + """Register default common strategies.""" + register_default_strategy(StrategyAction.LOAD_COMMON, 'torch', 1, TorchCommonLoadStrategy()) + register_default_strategy( + StrategyAction.SAVE_COMMON, 'torch', 1, TorchCommonSaveStrategy('torch', 1) + ) + + +class TorchCommonSaveStrategy(SaveCommonStrategy): + """Common save strategy leveraging native torch save/load.""" + + def save_common(self, common_state_dict: StateDict, checkpoint_dir: Path): + """Save common part of the state dict.""" + if torch.distributed.get_rank() == 0: + torch.save(common_state_dict, checkpoint_dir / COMMON_STATE_FNAME) + + def save_sharded_objects( + self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path + ): + """Save sharded objects from the state dict.""" + for sh_obj in nested_values(sharded_objects_state_dict): + if is_main_replica(sh_obj.replica_id): + save_path = checkpoint_dir / f'{sh_obj.unique_key}.pt' + os.makedirs(save_path.parent, exist_ok=True) + torch.save(sh_obj.data, save_path) + + def can_handle_sharded_objects(self): + """This strategy can handle ShardedObjects.""" + return True + + +class TorchCommonLoadStrategy(LoadCommonStrategy): + """Common load strategy leveraging native torch save/load.""" + + def load_common(self, checkpoint_dir: Path): + """Load common (non-sharded) objects state dict from the checkpoint. + + Args: + checkpoint_dir (Path): checkpoint directory + + Returns: + StateDict: state dict with non-sharded objects from the checkpoint + """ + load_path = Path(checkpoint_dir) / COMMON_STATE_FNAME + try: + return torch.load(load_path, map_location='cpu', weights_only=False) + except FileNotFoundError as e: + err_msg = f'Common file {load_path} does not exist' + ckpt_files = [f.name for f in checkpoint_dir.iterdir()] + logger.debug(f'{err_msg}. Checkpoint directory content: {ckpt_files}') + raise CheckpointingException(err_msg) from e + + def load_sharded_objects( + self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path + ): + """Replaces all ShardedObject from a given state dict with values loaded from the + checkpoint. + + Args: + sharded_objects_state_dict (ShardedStateDict): + sharded state dict defining what objects should be loaded. 
+ checkpoint_dir (Path): checkpoint directory + + Returns: + None: sharded state dict is modified in place + """ + + def load_sharded_object(sh_obj: ShardedObject): + sh_obj.data = None + load_path = checkpoint_dir / f'{sh_obj.unique_key}.pt' + try: + loaded_obj = torch.load(load_path, weights_only=False) + except FileNotFoundError as e: + # Backward compatible logic: previously the save format was incorrect + old_load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt') + try: + loaded_obj = torch.load(old_load_path, weights_only=False) + except FileNotFoundError: + err_msg = f'Object shard {load_path} not found' + obj_subdir = checkpoint_dir / sh_obj.key + if obj_subdir.exists(): + obj_files = [f.name for f in obj_subdir.iterdir()] + logger.debug( + f'{err_msg}. Object {sh_obj.key} directory content: {obj_files}' + ) + else: + ckpt_files = [f.name for f in checkpoint_dir.iterdir()] + logger.debug( + f'{err_msg}. Object {sh_obj.key} directory does not exist. Checkpoint' + f' directory content: {ckpt_files}' + ) + raise CheckpointingException(err_msg) from e + return loaded_obj + + return dict_list_map_inplace(load_sharded_object, sharded_objects_state_dict) + + def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict: + sharded_metadata = {} + for subdir in checkpoint_dir.iterdir(): + if not subdir.is_dir(): + continue + shard_files = list(subdir.glob('shard_*.pt')) + if not shard_files: + continue + sh_objs = [] + for shard_file in shard_files: + full_key = f'{subdir.name}/{shard_file.stem}' + sh_objs.append(ShardedObject.empty_from_unique_key(full_key)) + + # This is a backward-compatibility fix, where the last global shape is missing in the + # name + if sh_objs[0].global_shape[-1] < 0: + max_last_offset = max(map(lambda sh_obj: sh_obj.global_offset[-1], sh_objs)) + for sh_obj in sh_objs: + sh_obj.global_shape = (*sh_obj.global_shape[:-1], max_last_offset + 1) + + # Update the sharded state dict + for sh_obj in sh_objs: + sharded_metadata[sh_obj.unique_key] = sh_obj + return sharded_metadata + + @property + def can_handle_sharded_objects(self): + """This strategy can handle ShardedObjects.""" + return True + + def check_backend_compatibility(self, loaded_version): + pass + + def check_version_compatibility(self, loaded_version): + pass diff --git a/megatron/core/dist_checkpointing/strategies/filesystem_async.py b/megatron/core/dist_checkpointing/strategies/filesystem_async.py new file mode 100644 index 0000000000..2e1dda9756 --- /dev/null +++ b/megatron/core/dist_checkpointing/strategies/filesystem_async.py @@ -0,0 +1,496 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +""" Storage writer for PyT Distributed format allowing asynchronous save. 
""" +import dataclasses +import logging +import os +import queue +from functools import partial +from heapq import heappop, heappush +from itertools import chain +from operator import itemgetter +from pathlib import Path +from time import time +from typing import Callable, Dict, List, Optional, Tuple, Union + +import psutil +import torch +from torch import multiprocessing as mp +from torch.distributed.checkpoint import FileSystemWriter +from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item +from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteItem, WriteItemType +from torch.distributed.checkpoint.storage import WriteResult +from torch.futures import Future + +from .async_utils import _disable_gc + +logger = logging.getLogger(__name__) + +WriteBucket = Tuple[Path, str, Tuple[list, list]] # represents writes to a single file + +_results_queue = None + + +def _get_write_results_queue(): + global _results_queue + if _results_queue is None: + ctx = mp.get_context('spawn') + _results_queue = ctx.Manager().Queue() + return _results_queue + + +class FileSystemWriterAsync(FileSystemWriter): + """ + Async-enabled implementation of FileSystemWriter using file IO. + + This class doesn't spawn the async process itself, relies on the external async mechanism. + + Flow: + 1. Call `write_data` + 2. Externally start async process with `get_save_function_and_args` function and args + 3. The async function to call is `writer_proxy_func` which calls + `write_preloaded_data` in multiple processes + + After saving is finalized on all ranks: + 4. Call `super().finish` with the results gathered in `self.writer_result` + + Note that step (3) above can also be called synchronously. + + Currently, it's assumed that a separate writer is created for each ckpt save + (intermediate state is stored as writer attributes). + """ + + def __init__(self, *args, separation_hint: Optional[str] = None, **kwargs): + super().__init__(*args, **kwargs) + if not self.single_file_per_rank: + raise NotImplementedError( + 'single_file_per_rank flag not supported for FileSystemWriterAsync' + ) + + self.can_run_decentralized_global_plan: bool = True + + # Intermediate state between preparation and finalization + self.write_buckets: Optional[List[WriteBucket]] = None + self.results_queue: Optional[mp.Queue] = None + self.separation_hint = separation_hint + + def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None: + """ + First stage of async saving. Copy data to CPU and plan the local saving. 
+ + Args: + plan (SavePlan): save plan generated by the PyT Distributed compatible planner + planner (SavePlanner): save planner used to resolve the bytes and tensor data + + Returns: None, but stores the save plan in `self.write_buckets` + """ + storage_plan: _StoragePrefix = plan.storage_data + start = time() + logger.debug(f"thread_count: {self.thread_count}, time: {start}") + if self.separation_hint: + assert ( + self.thread_count > 1 + ), "thread_count must be at least 2 if separation_hint is provided" + bins = self.thread_count // 2 if self.separation_hint is not None else self.thread_count + item_buckets = _split_by_size_and_type(bins, plan.items) + logger.debug(f"bucket_prep, time: {time() - start}") + + start = time() + # move tensors from GPU to CPU before starting async writing + # We do D2H synchronously for now + file_count = 0 + + def gen_file(prefix=""): + nonlocal file_count + file_name = f"{prefix}{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}" + file_count += 1 + return file_name + + def _clone_if_needed(ten: torch.Tensor): + """Clone if we detect incontiguous storage for CPU tensors + + Makes sure we perform a `clone` only if we detect incontiguous storage, + so that we don't blow up host memory unnecessarily. + + TODO: For persistent worker, this work should be changed to move the cpu tensor + to shared_memory. + """ + ten = ten.detach() + if ten.device.type != "cpu": + # We do D2H later when the async_request is scheduled for both sync / async + # checkpointing + return ten + is_view = ten.untyped_storage().size() != ten.numel() * ten.itemsize + return ten.clone() if is_view else ten + + # Prepare bytes / tensor data in each bucket, which will be assigned to each writer process + self.write_buckets = [] + for group_name, group_buckets in _split_by_separation_hint( + item_buckets, self.separation_hint + ).items(): + for bucket in group_buckets: + bytes_data = [ + (item, planner.resolve_data(item)) + for item in bucket + if item.type == WriteItemType.BYTE_IO + ] + tensor_data = [ + (item, _clone_if_needed(planner.resolve_data(item))) + for item in bucket + if item.type != WriteItemType.BYTE_IO + ] + if len(bytes_data) > 0 or len(tensor_data) > 0: + file_name = gen_file(prefix=group_name) + self.write_buckets.append( + (self.path / file_name, file_name, (bytes_data, tensor_data)) + ) + + # Check if there is anything to write on this rank + if len(self.write_buckets) > 0: + assert len(self.write_buckets) <= self.thread_count, ( + len(self.write_buckets), + self.thread_count, + ) + self.results_queue = _get_write_results_queue() + else: + self.results_queue = None + end = time() + logger.debug(f"D2H and push, time: {end - start}") + + def get_save_function_and_args(self) -> Tuple[Optional[Callable], Optional[Callable], List]: + """ + Get function that saves the data to storage along with its arguments. + Allows the external caller to apply the save function synchronously or asynchronously. + + Returns: None (if there is nothing to write on this rank) or a tuple of: + 1) the function that saves the data. + 2) the function that stages the GPU tensors to a destination for async checkpointing. + This function should be self-contained. + 3) arguments to that function in 1). 
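+
+        Illustrative sketch of how the returned triple is typically wrapped into an
+        `AsyncRequest` (schematic only; `writer` denotes a `FileSystemWriterAsync` instance
+        on which `prepare_write_data` has already been called):
+
+            save_fn, preload_fn, save_args = writer.get_save_function_and_args()
+            # -> AsyncRequest(async_fn=save_fn, async_fn_args=save_args, preload_fn=preload_fn,
+            #                 finalize_fns=...) is then scheduled via an AsyncCallsQueue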
+ """ + if not self.write_buckets: + return None, None, () + return ( + self.write_preloaded_data_multiproc, + partial(self.preload_tensors, self.write_buckets, True), + [torch.distributed.get_rank(), self.write_buckets, self.results_queue], + ) + + @staticmethod + def preload_tensors(write_buckets: List[WriteBucket], non_blocking=True) -> List[WriteBucket]: + """Preload tensors in state_dict to host memory through CPU memory + Args: + write_buckets(List): List of `WriteBucket`, + which includes what to be saved in a checkpoint + non_blocking (bool, optional): knob to enable pinned D2H memcpy. Default is True. + """ + result = [] + + for bucket in write_buckets: + file_name, storage_key, (bytes_data, tensor_data) = bucket + tensor_data = [ + (item, tensor.to("cpu", non_blocking=non_blocking)) for item, tensor in tensor_data + ] + result.append((file_name, storage_key, (bytes_data, tensor_data))) + if non_blocking: + torch.cuda.synchronize() + return result + + @staticmethod + @_disable_gc() + def write_preloaded_data_multiproc( + rank, write_buckets: List[WriteBucket], global_results_queue: mp.Queue + ) -> None: + """ + Performs saving data to storage with multiple processes. + + Starts predefined number of processes and uses 2 queues to make sure the results + are complete: + - local_results_queue - to send the actual results + - count_queue - small queue to mark worker as completed + + Using just one queue disallowed proper exception handling. + + This method is meant to be run in a forked subprocess. + Triggering GC during execution leads to CUDA errors + (cleaning up tensors owned by the parent process). + To prevent this, we disable the GC explicitly for this function with _disable_gc. + + Args: + write_buckets (List[WriteBucket]): write plan + global_results_queue (mp.Queue): mp.Queue to collect Dict[List[WriteResults]] + (or an Exception) from parallel write processes to the main training process + Returns: None + """ + logger = logging.getLogger(__name__) + w_start = time() + write_results_or_exc: Union[dict, Exception] = dict() + ctx = mp.get_context('fork') + local_results_queue = ctx.Queue() + count_queue = ctx.JoinableQueue() + p_list = [] + for i, write_bucket in enumerate(write_buckets): + try: + count_queue.put(i) + p_list.append( + ctx.Process( + target=FileSystemWriterAsync.write_preloaded_data, + args=(i, write_bucket, local_results_queue, count_queue, True), + ) + ) + except Exception as e: + err_msg = f'An error is caught while a proc {i} is created, error: {e}' + logger.error(err_msg) + write_results_or_exc = RuntimeError(err_msg) + + if not isinstance(write_results_or_exc, Exception): + for p in p_list: + p.start() + + logger.debug('FileSystemWriterAsync: collecting worker results...') + + # To make sure all nodes are completed + count_queue.join() + # At this point, all workers completed, so the queue should have exactly + # `len(write_buckets)` items + for proc_idx in range(len(write_buckets)): + try: + local_proc_idx, local_results_or_exc = local_results_queue.get() + except queue.Empty: + write_results_or_exc = RuntimeError( + f'Unexpected empty `local_results_queue`' + f' (got only {proc_idx}/{len(write_buckets)} items)' + ) + break + else: + if isinstance(local_results_or_exc, Exception): + err_msg = ( + f"Local process {local_proc_idx} encountered" + f" an error: {local_results_or_exc}" + ) + logger.error(err_msg) + write_results_or_exc = local_results_or_exc + break + assert isinstance(local_results_or_exc, list), type(local_results_or_exc) + 
write_results_or_exc[local_proc_idx] = local_results_or_exc + p_list[local_proc_idx].join() + + logger.debug('FileSystemWriterAsync: collected worker results successfully') + + global_results_queue.put(write_results_or_exc) + + w_end = time() + logger.debug(f"{w_end}, rank: {rank}," f" write(sync,parallel): {w_end - w_start}") + + @staticmethod + @_disable_gc() + def write_preloaded_data( + local_proc_idx: int, + write_bucket: WriteBucket, + results_queue: mp.SimpleQueue, + count_queue: mp.JoinableQueue, + use_fsync: bool, + ) -> None: + """ + Performs actual data saving to storage. + + Args: + local_proc_idx (int): index of a local process that performs writing + write_bucket (WriteBucket): data to write to storage + results_queue (mp.Queue): queue to return the write results + to the proxy checkpoint process. + count_queue (mp.JoinableQueue): queue to marks worker task as completed + use_fsync (bool): if True, calls os.fsync at the end of saving + + Returns: None, the write result are put into the `queue` + """ + logger = logging.getLogger(__name__) + logger.debug(f'{local_proc_idx} started') + mem_before = _process_memory() + + local_results = [] + try: + file_name, storage_key, (bytes_data, tensor_data) = write_bucket + with open(file_name, "wb") as stream: + for write_item, data in bytes_data: + local_results.append(_write_item(stream, data, write_item, storage_key)) + + for write_item, tensor in tensor_data: + assert tensor.is_cpu + local_results.append(_write_item(stream, tensor, write_item, storage_key)) + + if use_fsync: + os.fsync(stream.fileno()) + local_output = (local_proc_idx, local_results) + except Exception as e: + logger.debug(f'{local_proc_idx} failed') + local_output = (local_proc_idx, e) + + results_queue.put(local_output) + # Signal this process is done. + count_queue.get() + count_queue.task_done() + + mem_after = _process_memory() + logger.debug( + f"{local_proc_idx} consumed: {mem_after - mem_before}," + f" before: {mem_before}, after: {mem_after}" + ) + + def write_data(self, plan: SavePlan, planner: SavePlanner) -> Future[List[WriteResult]]: + """Write all items from ``plan``.""" + raise NotImplementedError('write_data not implemented for FileSystemWriterAsync') + + def retrieve_write_results(self) -> List[WriteResult]: + """ + Turn the latest dict including write results from `self.results_queue` + into a single results lists. Includes error check. + + Returns (List[WriteResult]): the list of write results + from all local processes performing the save. + + """ + assert self.write_buckets is not None + + if self.results_queue is None: + write_results_or_exc = {} + else: + try: + write_results_or_exc = self.results_queue.get_nowait() + except queue.Empty: + raise RuntimeError(f'results_queue should not be empty') + + if isinstance(write_results_or_exc, Exception): + raise RuntimeError(f'Worker failure: {write_results_or_exc}') from write_results_or_exc + write_results: dict = write_results_or_exc + if len(write_results) != len(self.write_buckets): + raise RuntimeError( + f'Incomplete worker results (expected {len(self.write_buckets)},' + f' got {len(write_results)}. This probably indicates a worker failure.' + ) + return list(chain.from_iterable(write_results.values())) + + def prepare_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan: + """Instead of assigning indices by plan order, uses PyT rank (same outcome). 
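+
+        Concretely, the local plan is simply stamped with a rank-derived `_StoragePrefix`
+        (e.g. `__3_` on rank 3), so the resulting file names stay unique across ranks
+        without any coordinator round-trip.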
+ + Args: + local_plan (SavePlan): local plan to turn to a global plan + (without interactions with other ranks) + + Returns: + SavePlan - locally transformed plan equivalent to the plan that would be + created by the coordinator + """ + return dataclasses.replace( + local_plan, storage_data=_StoragePrefix(f"__{torch.distributed.get_rank()}_") + ) + + +def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]: + """ + Splits write items according to item size into close to uniform bins. + + Same as torch.distributed.checkpoint.filesystem._split_by_size_and_type, + but with a fixed _item_size function. + + Args: + bins (int): numbers of bins to split to + items (List[WriteItem]): list of write items + + Returns (List[List[WriteItem]]): write items split to bins + """ + if bins == 1: + return [items] + + bytes_items: List[WriteItem] = [] + tensor_items: List[WriteItem] = [] + for wi in items: + container = bytes_items if wi.type == WriteItemType.BYTE_IO else tensor_items + container.append(wi) + + buckets: List[List[WriteItem]] = [[] for _ in range(bins)] + bucket_sizes = [0 for _ in range(bins)] + + # Assign bytes with a simple round-robin + for i, item in enumerate(bytes_items): + buckets[i % bins].append(item) + + # Sort tensor items by size in decreasing order once and store the size with item + sized_tensors = [(item, _item_size(item)) for item in tensor_items] + sized_tensors.sort(key=itemgetter(1), reverse=True) + + # Use a min heap for bin assignment + # Store (total_size_of_bin, bin_index) tuples + heap: List[Tuple[int, int]] = [(0, i) for i in range(bins)] + + # Assign tensors using heap + for item, size in sized_tensors: + total_bin_size, bin_idx = heappop(heap) + buckets[bin_idx].append(item) + heappush(heap, (total_bin_size + size, bin_idx)) + + return buckets + + +def _split_by_separation_hint( + buckets: List[List[WriteItem]], separation_hint: Optional[str] = None +) -> Dict[str, List[List[WriteItem]]]: + """ + Splits buckets into those whose keys begin with the separation_hint and those whose keys do not + + Args: + buckets (List[List[WriteItem]]): buckets to split + separation_hint (Optional[str]): optional prefix to split on + + Returns (Dict[str, List[List[WriteItem]]]): a dictionary + mapping the prefix to the relevant buckets + """ + bins = len(buckets) + buckets_with_separation_hint = {} + if separation_hint is not None: + buckets_default = [[] for _ in range(bins)] + buckets_hint = [[] for _ in range(bins)] + for i in range(bins): + for item in buckets[i]: + if item.index.fqn.startswith(separation_hint): + buckets_hint[i].append(item) + else: + buckets_default[i].append(item) + buckets_with_separation_hint[""] = buckets_default + buckets_with_separation_hint[separation_hint] = buckets_hint + else: + buckets_with_separation_hint[""] = buckets + return buckets_with_separation_hint + + +def _item_size(item: WriteItem) -> int: + """ + Calculates size (in bytes) of a single write item. 
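+
+    The size is `prod(item.tensor_data.chunk.sizes) * element_size(dtype)`; for example,
+    a (1024, 1024) fp16 chunk amounts to 1024 * 1024 * 2 bytes = 2 MiB.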
+ + Same as torch.distributed.checkpoint.filesystem._item_size, + but fixes computing chunk size (with item.tensor_data.chunk.sizes) + + Args: + item (WriteItem): write item to compute the size of + + Returns (int): size of an item in bytes + """ + size = 1 + assert item.tensor_data is not None + # can't use math.prod as PT needs to support older python + for s in item.tensor_data.chunk.sizes: + size *= s + + dtype = item.tensor_data.properties.dtype + return size * torch._utils._element_size(dtype) + + +def _process_memory() -> int: + """ + Get memory used by current process. + + Returns (int): memory used by current process + """ + process = psutil.Process(os.getpid()) + mem_info = process.memory_info() + return mem_info.rss diff --git a/megatron/core/dist_checkpointing/strategies/fully_parallel.py b/megatron/core/dist_checkpointing/strategies/fully_parallel.py new file mode 100644 index 0000000000..269a371d49 --- /dev/null +++ b/megatron/core/dist_checkpointing/strategies/fully_parallel.py @@ -0,0 +1,515 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import logging +from pathlib import Path +from time import time +from typing import Any, Callable, Dict, Optional, Tuple, TypeVar + +import torch +import torch.distributed as dist +from torch.distributed.checkpoint import Metadata + +from megatron.core.dist_checkpointing import ShardedObject, ShardedTensor +from megatron.core.dist_checkpointing.core import CheckpointingException +from megatron.core.dist_checkpointing.dict_utils import ( + dict_list_map_inplace, + extract_matching_values, + merge, + nested_values, +) +from megatron.core.dist_checkpointing.exchange_utils import ( + ShardDistribution, + determine_main_replica_uniform_distribution, + exchange_by_distribution, + exchange_loaded_objects_gather_object, +) +from megatron.core.dist_checkpointing.mapping import ShardedStateDict, StateDict, is_main_replica +from megatron.core.dist_checkpointing.strategies.base import ( + AsyncSaveShardedStrategy, + LoadShardedStrategy, + SaveShardedStrategy, +) +from megatron.core.dist_checkpointing.utils import ( + _sharded_object_id, + _sharded_tensor_shard_id, + _ShardId, + debug_time, +) +from megatron.core.dist_checkpointing.validation import ( + determine_global_metadata, + validate_sharding_integrity, +) + +logger = logging.getLogger(__name__) + +T = TypeVar('T', ShardedObject, ShardedTensor) + + +class FullyParallelSaveStrategyWrapper(AsyncSaveShardedStrategy): + """Wraps arbitrary strategy and distributes the save during `save`. + + The save distribution happens without any *data* communication. + Only the *metadata* is exchanged and based on data replication on different + ranks, we try to distribute the save as uniformly as possible. + + This wrapper assumes, that setting `replica_id` to 0 will make the + underlying strategy do the saving on current rank. All the other `replica_id`s + are set to 1. + + Currently, the save distribution is realized with a greedy algorithm + described in `distribute_shards_to_ranks`. + + Args: + strategy (SaveShardedStrategy): base strategy to wrap + parallelization_group (ProcessGroup, optional): process group to use for save + distribution. Note that this doesn't have to match exactly the + data distribution, but should cover the replication pattern + to maximize performance. Defaults to the whole world. + do_cache_distribution (bool, optional): whether to cache the save distribution + from previous calls. 
Should be set to True only if the state dict + structure between the calls is always the same. Defaults to True. + """ + + def __init__( + self, + strategy: SaveShardedStrategy, + parallelization_group: Optional[torch.distributed.ProcessGroup] = None, + do_cache_distribution: bool = False, + ): + super().__init__(strategy.backend, strategy.version) + self.base_strategy = strategy + self.parallelization_group = parallelization_group + self.do_cache_distribution = do_cache_distribution + + self.cached_distribution: Optional[ShardDistribution] = None + + def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): + if not isinstance(self.base_strategy, AsyncSaveShardedStrategy): + raise CheckpointingException( + f'Cannot apply async_save to non-async base strategy {self.base_strategy}' + ) + self.apply_saving_parallelization(sharded_state_dict) + return self.base_strategy.async_save(sharded_state_dict, checkpoint_dir) + + def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): + self.apply_saving_parallelization(sharded_state_dict) + return self.base_strategy.save(sharded_state_dict, checkpoint_dir) + + def apply_saving_parallelization(self, sharded_state_dict: ShardedStateDict) -> None: + """Distributes the save across ranks by exchanging metadata. + + Exchanges metadata from the state dict and computes the uniform + (as close as possible) distribution of saves among the ranks. + + If `self.do_cache_distribution` is True, caches the distribution between + the calls and subsequent distributions happen without any inter-rank + communication. + + Args: + sharded_state_dict (ShardedStateDict): state dict to distribute the saving + + Returns: None + """ + start = time() + if self.do_cache_distribution and self.cached_distribution is not None: + logger.debug(f'Apply *cached* save parallelization') + precomputed_distribution = self.cached_distribution + else: + logger.debug(f'Apply save parallelization') + precomputed_distribution = determine_main_replica_uniform_distribution( + sharded_state_dict, self.parallelization_group + ) + + distribute_main_replicas_with_precomputed_distribution( + sharded_state_dict, self.parallelization_group, precomputed_distribution + ) + if self.cached_distribution is None: + # First time applying the parallelization + validate_sharding_integrity(determine_global_metadata(sharded_state_dict)[1]) + if self.do_cache_distribution: + self.cached_distribution = precomputed_distribution + end = time() + logger.debug(f"parallel save sharding, time: {end - start}") + + @property + def can_handle_sharded_objects(self): + return self.base_strategy.can_handle_sharded_objects + + +class FullyParallelLoadStrategyWrapper(LoadShardedStrategy): + """Wraps arbitrary load strategy and distributes the load during `load`. + + See `load` method docs for details. + + Args: + strategy (LoadShardedStrategy): base strategy to wrap + parallelization_group (ProcessGroup, optional): process group to use for load + distribution. Note that this doesn't have to match exactly the + data distribution, but should cover the replication pattern + to maximize performance. Defaults to the whole world. + In most cases, it's recommended to set it to the DP group. + do_cache_distribution (bool, optional): whether to cache the load distribution + from previous calls. Should be set to True only if the state dict + structure between the calls is always the same. Defaults to False, + since the loading in general happens only once during training. 
+ Note that the load distribution *cannot* be reused as a save distribution, + because save/load is not fully symmetrical. + exchange_algo (str): algorithm to use for exchanging the data. + Options: + - broadcast - each rank broadcasts individual tensors to others + - gather_object (default) - ranks all_gather_object the whole loaded state dicts + - gather_rounds (default) - ranks all gather individual tensors in rounds + See method docs for more details. + """ + + def __init__( + self, + strategy: LoadShardedStrategy, + parallelization_group: Optional[torch.distributed.ProcessGroup] = None, + do_cache_distribution: bool = False, + exchange_algo: str = 'broadcast', + ): + super().__init__() + self.base_strategy = strategy + if parallelization_group is None: + parallelization_group = ( + dist.GroupMember.WORLD + ) # explicit group needed for torch.distributed.get_global_rank call + self.parallelization_group = parallelization_group + self.do_cache_distribution = do_cache_distribution + self.exchange_algo = exchange_algo + + self.cached_distribution: Optional[ShardDistribution] = None + self.cached_global_metadata: Optional[Metadata] = None + + @debug_time("FullyParallelLoadStrategyWrapper.load", logger) + def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict: + """Distributes the load and calls underlying strategy only for parts of the state dict. + + Steps: + 1. Load metadata is exchanged between the ranks in the parallelization group. + 2. Each rank deterministically plans the load for the whole workload + so that the loads are as uniform as possible. + 3. Each ranks loads its planned shard of the checkpoint. + 4. All ranks exchange the loaded shards. + + Internode communication is involved in steps (1) (with metadata) + and (4) (with actual data). Storage interaction is involved in step (3). + + Currently, the load distribution (step 2) is realized with a greedy algorithm + described in `distribute_shards_to_ranks` (same as for saving distribution). + + Currently, the shards are all gathered between all ranks in the parallelization + group. This might not be optimal (some ranks do not need all tensors), + but it's a reasonable approximation for an optimal exchange in most scenarios. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to load + checkpoint_dir (Path): checkpoint directory to load from + + Returns: + StateDict: loaded state dict. The state dict should be equivalent to + a state dict that would be loaded with the underlying strategy + without this wrapper. + """ + + loaded_state_dict = {} + + if torch.distributed.get_world_size(self.parallelization_group) <= 1: + return self.base_strategy.load(sharded_state_dict, checkpoint_dir) + + # Step 1 and 2: exchange load metadata and distribute the load + with debug_time("self.apply_loading_parallelization", logger): + precomputed_distribution: ShardDistribution | None = self.apply_loading_parallelization( + sharded_state_dict + ) + assert ( + precomputed_distribution is not None + ), 'Expecting non-trivial distribution for non-trivial parallelization group' + + # Step 3: load part of the checkpoint. + # Load only sharded objects first. 
ShardedTensors will be loaded separately + # so that we can keep track of sharded tensors loaded by this rank + (sharded_tensors, sharded_state_dict, to_load_shards, unloaded_shards) = ( + self._defer_loading_sharded_tensors(sharded_state_dict) + ) + + (sharded_objects, sharded_state_dict, to_load_objects, unloaded_objects) = ( + self._defer_loading_sharded_objects(sharded_state_dict) + ) + + assert ( + len(sharded_state_dict) == 0 + ), "sharded_state_dict is not empty after deferring tensors and objects" + with debug_time("base_load_ShardedObjects", logger): + # Load sharded objects first + loaded_objects = self.base_strategy.load(to_load_objects, checkpoint_dir) + + with debug_time("base_load_ShardedTensors", logger): + # Load sharded tensors separately + loaded_tensors = self.base_strategy.load(to_load_shards, checkpoint_dir) + + with debug_time("self.exchange_loaded_tensors", logger): + + # Step 4: exchange data between ranks + logger.debug(f'Applying parallel load with algo {self.exchange_algo}') + all_loaded_tensors = exchange_by_distribution( + loaded_tensors, + unloaded_shards, + precomputed_distribution, + self.parallelization_group, + self.exchange_algo, + ) + if not set(unloaded_shards.keys()).issubset(all_loaded_tensors.keys()): + missing_shards = set(unloaded_shards.keys()) - all_loaded_tensors.keys() + raise CheckpointingException( + f'Missing shards after fully parallel loading: {missing_shards}' + ) + + with debug_time("torch.cuda.synchronize", logger): + torch.cuda.synchronize() + + all_loaded_objects = exchange_loaded_objects_gather_object(loaded_objects) + + if not set(unloaded_objects.keys()).issubset(all_loaded_objects.keys()): + missing_object_shards = set(unloaded_objects.keys()) - all_loaded_objects.keys() + raise CheckpointingException( + f'Missing object shards after fully parallel loading: {missing_object_shards}' + ) + torch.cuda.synchronize() + + self.fill_in_deferred_sharded_tensors(sharded_tensors, all_loaded_tensors) + self.fill_in_deferred_sharded_objects(sharded_objects, all_loaded_objects) + + merge(loaded_state_dict, sharded_objects) + merge(loaded_state_dict, sharded_tensors) + if hasattr(self.base_strategy, "cached_global_metadata"): + self.cached_global_metadata = self.base_strategy.cached_global_metadata + return loaded_state_dict + + @staticmethod + def _defer_loading_sharded_objects( + sharded_state_dict: ShardedStateDict, + ) -> Tuple[ + ShardedStateDict, + ShardedStateDict, + Dict[_ShardId, ShardedObject], + Dict[_ShardId, ShardedObject], + ]: + return _defer_loading_sharded_items(sharded_state_dict, ShardedObject, _sharded_object_id) + + @staticmethod + def _defer_loading_sharded_tensors( + sharded_state_dict: ShardedStateDict, + ) -> Tuple[ + ShardedStateDict, + ShardedStateDict, + Dict[_ShardId, ShardedTensor], + Dict[_ShardId, ShardedTensor], + ]: + return _defer_loading_sharded_items( + sharded_state_dict, ShardedTensor, _sharded_tensor_shard_id + ) + + @staticmethod + def fill_in_deferred_sharded_objects( + sharded_state_dict: ShardedStateDict, loaded_objects: Dict[_ShardId, Any] + ) -> None: + """Fill in objects not loaded by current rank with objects from `loaded_objects` map. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to fill in. + ShardedObjects are completely replaced with corresponding objects. + loaded_objects (Dict[_ShardId, Any]): dict allowing to map + ShardedObject from the sharded_state_dict to loaded objects. 
+ + Returns: + None + """ + _fill_in_deferred_sharded_items( + sharded_state_dict, loaded_objects, ShardedObject, _sharded_object_id + ) + + @staticmethod + def fill_in_deferred_sharded_tensors( + sharded_state_dict: ShardedStateDict, loaded_tensors: Dict[_ShardId, torch.Tensor] + ) -> None: + """Fill in tensors not loaded by current rank with tensors from `loaded_tensors` map. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to fill in. + ShardedTensors are completely replaced with corresponding torch.Tensors. + loaded_tensors (Dict[_ShardId, torch.Tensor]): dict allowing to map + ShardedTensor from the sharded_state_dict to loaded tensors. + + Returns: + None + """ + _fill_in_deferred_sharded_items( + sharded_state_dict, loaded_tensors, ShardedTensor, _sharded_tensor_shard_id + ) + + def apply_loading_parallelization( + self, sharded_state_dict: ShardedStateDict + ) -> Optional[ShardDistribution]: + """Distributes the load across ranks by exchanging metadata. + + Exchanges metadata from the state dict and computes the uniform + (as close as possible) distribution of loads among the ranks. + Marks ShardedTensors to be loaded by the current rank with replica_id 0 + (and others with non 0 values). + + If `self.do_cache_distribution` is True, caches the distribution between + the calls and subsequent distributions happen without any inter-rank + communication. + + Args: + sharded_state_dict (ShardedStateDict): state dict to distribute the loading + + Returns: + ShardDistribution (optional): the computed loading distribution + """ + if self.do_cache_distribution and self.cached_distribution is not None: + logger.debug(f'Apply *cached* load parallelization') + precomputed_distribution = self.cached_distribution + else: + logger.debug(f'Apply load parallelization') + precomputed_distribution = determine_main_replica_uniform_distribution( + sharded_state_dict, self.parallelization_group, True + ) + + distribute_main_replicas_with_precomputed_distribution( + sharded_state_dict, self.parallelization_group, precomputed_distribution + ) + if self.do_cache_distribution: + self.cached_distribution = precomputed_distribution + + return precomputed_distribution + + @property + def can_handle_sharded_objects(self): + return self.base_strategy.can_handle_sharded_objects + + def load_tensors_metadata(self, checkpoint_dir: Path): + return self.base_strategy.load_tensors_metadata(checkpoint_dir) + + def load_sharded_metadata(self, checkpoint_dir: Path): + return self.base_strategy.load_sharded_metadata(checkpoint_dir) + + def check_backend_compatibility(self, loaded_version): + return self.base_strategy.check_backend_compatibility(loaded_version) + + def check_version_compatibility(self, loaded_version): + return self.base_strategy.check_version_compatibility(loaded_version) + + +def distribute_main_replicas_with_precomputed_distribution( + sharded_state_dict: ShardedStateDict, + parallelization_group: torch.distributed.ProcessGroup, + precomputed_distribution: Optional[ShardDistribution], +): + """Applies the save distribution computed with `determine_main_replica_uniform_distribution`. + + Based on rank assignment, sets replica ids of the shards saved by current rank to 0 + and all the other replica ids to 1. + + Args: + sharded_state_dict (ShardedStateDict): state dict to apply the save distribution to + parallelization_group (ProcessGroup): distribution will be applied within this + process group. 
Must match with the process group passed to + `determine_main_replica_uniform_distribution`. + precomputed_distribution (ShardDistribution): distribution computed with + `determine_main_replica_uniform_distribution` + + Returns: None + + Example replica ids of tensors A, B, C before distribution: + rank0: A: (0, 0, 0), B: (0, 0, 0), C: (0, 0, 0) + rank1: A: (0, 0, 1), B: (0, 0, 1), C: (0, 0, 1) + rank2: A: (0, 0, 2), B: (0, 0, 2), C: (0, 0, 2) + + Replicas after distribution for the example above: + rank0: A: 0, B: 1, C: 1 + rank1: A: 1, B: 0, C: 1 + rank2: A: 1, B: 1, C: 0 + """ + if torch.distributed.get_world_size(group=parallelization_group) <= 1: + return + if precomputed_distribution is None: + raise ValueError( + 'precomputed_distribution must be not None for non-trivial parallelization group' + ) + + local_shards = list( + sh_base + for sh_base in nested_values(sharded_state_dict) + if isinstance(sh_base, ShardedTensor) + ) + + rank_within_dp_group = torch.distributed.get_rank(parallelization_group) + for sh_ten in local_shards: + shard_id = _sharded_tensor_shard_id(sh_ten) + if ( + shard_id in precomputed_distribution.shards_in_this_group + and rank_within_dp_group == precomputed_distribution.main_rank_for_shard[shard_id] + ): + sh_ten.replica_id = 0 + else: + sh_ten.replica_id = 1 + + +def _defer_loading_sharded_items( + sharded_state_dict: ShardedStateDict, item_type: type, shard_id_func: Callable[[T], _ShardId] +) -> Tuple[ShardedStateDict, ShardedStateDict, Dict[_ShardId, T], Dict[_ShardId, T]]: + """Divides state dict into parts loaded by this vs other ranks. + + Args: + sharded_state_dict (ShardedStateDict): state dict with sharded items + that will be divided. + item_type: The type of sharded item (ShardedObject or ShardedTensor) + shard_id_func: Function to get the shard ID for the item type + + Returns: a tuple of: + - ShardedStateDict: sub-state dict only with sharded items + - ShardedStateDict: sub-state dict with non-sharded items + - Dict[_ShardId, T]: mapping from shard id to items loaded by *this* rank + - Dict[_ShardId, T]: mapping from shard id to items loaded by *other* ranks + """ + to_load_shards = {} + unloaded_shards = {} + + sharded_items, remaining_state_dict = extract_matching_values( + sharded_state_dict, lambda v: isinstance(v, item_type) + ) + + def wrap_non_main_replicas(x: Any) -> Any: + if isinstance(x, item_type): + shard_id = shard_id_func(x) + if is_main_replica(x.replica_id): + to_load_shards[shard_id] = x + else: + unloaded_shards[shard_id] = x + return x + + dict_list_map_inplace(wrap_non_main_replicas, sharded_items) + return sharded_items, remaining_state_dict, to_load_shards, unloaded_shards + + +def _fill_in_deferred_sharded_items( + sharded_state_dict: ShardedStateDict, + loaded_items: Dict[_ShardId, Any], + item_type: type, + shard_id_func: Callable[[T], _ShardId], +) -> None: + """Helper function to fill in items not loaded by current rank.""" + + def fill_in_sharded_item(x: Any) -> Any: + if isinstance(x, item_type): + try: + x = loaded_items[shard_id_func(x)] + except KeyError as e: + raise CheckpointingException( + f'Missing loaded item shard: {shard_id_func(x)}' + ) from e + return x + + dict_list_map_inplace(fill_in_sharded_item, sharded_state_dict) diff --git a/megatron/core/dist_checkpointing/strategies/resharding.py b/megatron/core/dist_checkpointing/strategies/resharding.py new file mode 100644 index 0000000000..9cf5884455 --- /dev/null +++ b/megatron/core/dist_checkpointing/strategies/resharding.py @@ -0,0 +1,318 @@ +# 
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" Performant resharding of flattened tensors. + +Tensors that are first sharded (e.g. across TP) and then flattened cause +very irregular access patterns during loading. The idea for performant save/load +is to store tensors with global shape [X, Y, Z] and local shape [x, y, z] +as tensors with global shape [X // x, Y // y, Z // z, x * y * z] and +local shape [1, 1, 1, x * y * z]. This allows parallel save of tensors along the +last (flattened) dimension. During loading, some additional resharding is needed. +""" +import logging +import math +from dataclasses import dataclass +from itertools import product +from typing import Any, Dict, Tuple, Union + +import numpy as np +import torch +from torch.distributed.checkpoint import ChunkStorageMetadata +from torch.distributed.checkpoint.resharding import _shards_get_overlap_region_wrt_saved_tensor + +from megatron.core.dist_checkpointing import ShardedTensor +from megatron.core.dist_checkpointing.core import CheckpointingException +from megatron.core.dist_checkpointing.dict_utils import ( + dict_list_map_inplace, + extract_matching_values, +) +from megatron.core.dist_checkpointing.mapping import ( + ShardedStateDict, + ShardedTensorFactory, + StateDict, + apply_factories, + apply_factory_merges, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class TensorReformulationMetadata: + """Metadata needed to restore the original tensor shape. + + Args: + ckpt_orig_global_shape (Tuple[int, ...]): original global shape of the tensor + saved in the checkpoint. This is the global shape of the application, + further reformulated into `ckpt_reform_global_shape` while saving. + ckpt_reform_global_shape (Tuple[int, ...]): reformulated global shape of the tensor + saved in the checkpoint. This is the actual saved shape. + """ + + ckpt_orig_global_shape: Tuple[int, ...] + ckpt_reform_global_shape: Tuple[int, ...] + + def __post_init__(self): + assert self.ckpt_orig_global_shape + + +def nd_flattened_tensor_reformulated_global_shape(sh_ten: ShardedTensor) -> Tuple[int, ...]: + """Reformulated global shape of the flattened N-D ShardedTensor. + + N-D tensor global shape [X, Y, Z] and local shape [x, y, z] + is reformulated into global shape [X // x, Y // y, Z // z, x * y * z] and + local shape [1, 1, 1, x * y * z], to allow parallel save of tensors along the + last (flattened) dimension. + + Args: + sh_ten (ShardedTensor): flattened N-D ShardedTensor (N > 1) + + Returns: + Tuple[int, ...]: reformulated tensor shape + """ + assert is_nd_flattened_tensor(sh_ten), sh_ten + return sh_ten.axis_fragmentations + (int(np.prod(sh_ten.local_shape)),) + + +def is_nd_flattened_tensor(sh_ten: Any) -> bool: + """Checks if ShardedTensor is flattened and more than 1-dimensional + + Args: + sh_ten (Any): any object + + Returns: + bool: whether the given object is a flattened ShardedTensor and is N-dimensional (N > 1) + """ + return isinstance(sh_ten, ShardedTensor) and sh_ten.flattened_range is not None + + +# information needed to restore. With current implementation, this is a nested state dict +# with ShardedTensorFactories which is basically a ShardedStateDict type +ReformulationRestoreMetadata = ShardedStateDict + + +def apply_nd_flattened_tensors_reformulation( + sharded_state_dict: ShardedStateDict, + reformulation_metadata: Dict[str, TensorReformulationMetadata], +) -> Tuple[ShardedStateDict, ReformulationRestoreMetadata]: + """Applies N-D reformulation to a given sharded state dict. 
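[Editor's note] A minimal sketch with toy shapes (values assumed, not Megatron code) of the shape arithmetic that `nd_flattened_tensor_reformulated_global_shape` performs above:

import numpy as np

# Toy tensor: global shape [X, Y] split into a regular grid of local chunks.
global_shape = (8, 12)
local_shape = (4, 3)          # each rank holds one grid cell
axis_fragmentations = tuple(g // l for g, l in zip(global_shape, local_shape))  # (2, 4)

# Reformulated checkpoint shape: one entry per local chunk along the original
# axes, plus a flattened last axis holding the chunk's elements.
reform_global_shape = axis_fragmentations + (int(np.prod(local_shape)),)
print(reform_global_shape)    # (2, 4, 12)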
+ + After applying the method and loading the reformulated state dict, + the `restore_nd_flattened_tensors_formulation` needs to be applied. + + Current implementation uses ShardedTensorFactories for convenience of + restoring the original structure, but it's just an implementation detail. + Turns N-D ShardedTensors into factories and immediately applies them, + keeping the data needed to restore the original structure. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict potentially + with tensors to reformulate. + reformulation_metadata (Dict[str, TensorReformulationMetadata]): dict + containing all metadata needed for reformulating tensors in `sharded_state_dict`. + for each N-D flattened tensor `sh_ten` in `sharded_state_dict` there must be an + entry with `sh_ten.key`. + + Returns: + tuple: + ShardedStateDict - reformulated sharded state dict + ReformulationRestoreMetadata - data needed to restore the original formulation + with `restore_nd_flattened_tensors_formulation` + """ + + def maybe_reformulate_nd_flattened_tensor(sh_ten: Any): + if not isinstance(sh_ten, ShardedTensor) or not is_nd_flattened_tensor(sh_ten): + return sh_ten + # N-D flattened ShardedTensor + try: + sh_ten_reformulation_metadata = reformulation_metadata[sh_ten.key] + except KeyError as e: + # Handle legacy checkpointing where 1-D flatten tensor metadata was not saved + if len(sh_ten.global_shape) == 1: + return sh_ten + raise CheckpointingException( + f'Missing reformulation metadata for tensor {sh_ten}. ' + f'Existing keys: {reformulation_metadata.keys()}' + ) from e + + ckpt_actual_saved_shape = sh_ten_reformulation_metadata.ckpt_reform_global_shape + app_actual_load_shape = nd_flattened_tensor_reformulated_global_shape(sh_ten) + if ckpt_actual_saved_shape == app_actual_load_shape: + # Same shape - no need to reshard + return sh_ten + + return reformulate_single_nd_flattened_tensor(sh_ten, sh_ten_reformulation_metadata) + + # Turn N-D tensors into factories and immediately apply them + dict_list_map_inplace(maybe_reformulate_nd_flattened_tensor, sharded_state_dict) + sh_ten_factories, _ = extract_matching_values( + sharded_state_dict, + lambda x: isinstance(x, ShardedTensorFactory), + return_lists_as_dicts=True, + ) + apply_factories(sharded_state_dict) + + # Unlink `data` pointers to free memory + def unlink_data(x): + x.data = None + return x + + dict_list_map_inplace(unlink_data, sh_ten_factories) + return sharded_state_dict, sh_ten_factories + + +def restore_nd_flattened_tensors_formulation( + state_dict: StateDict, formulation_restore_metadata: ReformulationRestoreMetadata +) -> StateDict: + """Restores the original state dict from a reformulated form. + + Inverse of `apply_nd_flattened_tensors_reformulation`. + + Args: + state_dict (StateDict): state dict obtained by loading a reformulated + sharded state dict. + formulation_restore_metadata (ReformulationRestoreMetadata): metadata returned by + `apply_nd_flattened_tensors_reformulation` function + + Returns: + StateDict: state dict with the original tensors formulation restored + """ + return apply_factory_merges(state_dict, formulation_restore_metadata) + + +def reformulate_single_nd_flattened_tensor( + sh_ten: ShardedTensor, reformulation_metadata: TensorReformulationMetadata +) -> Union[Any, ShardedTensorFactory]: + """Reformulates shapes of a single N-D flattened ShardedTensor. 
+ + We need to define a pair of transformations: + - turn N-D ShardedTensor with original formulation into multiple reformulated ShardedTensors + - merge multiple reformulated loaded torch.Tensors into a single original tensor + Current implementation uses ShardedTensorFactories as a convenient mechanism + for specifying and keeping track of those transformations. + + Args: + sh_ten (ShardedTensor): sharded tensor to reformulate. + reformulation_metadata (TensorReformulationMetadata): metadata needed to + perform the reformulation + + Returns: + ShardedTensorFactory: factory that keeps information how to reformulate + (build) the ShardedTensor and then restore original formulation (merge) + after loading. + """ + rmd = reformulation_metadata + # Data won't be needed - remove unnecessary tensor references + sh_ten = sh_ten.without_data() + + # Based on reformulation_metadata, determine other tensor shapes and metadata + ckpt_axis_fragmentation = rmd.ckpt_reform_global_shape[:-1] + for sh, fragm in zip(rmd.ckpt_orig_global_shape, ckpt_axis_fragmentation): + assert sh % fragm == 0, (sh_ten, rmd.ckpt_reform_global_shape) + ckpt_local_shape_with_prepended_axis = tuple( + sh // fragm for sh, fragm in zip(rmd.ckpt_orig_global_shape, ckpt_axis_fragmentation) + ) + assert ( + ckpt_local_shape_with_prepended_axis[: sh_ten.prepend_axis_num] + == (1,) * sh_ten.prepend_axis_num + ), (ckpt_local_shape_with_prepended_axis, sh_ten) + ckpt_local_shape = ckpt_local_shape_with_prepended_axis[sh_ten.prepend_axis_num :] + + # Iterate over reformulated shapes needed by the application and from checkpoint, + # and generate new ShardedTensors that match the checkpoint sharding. + overlap_dim_offsets = [] + assert len(ckpt_axis_fragmentation) == len(sh_ten.axis_fragmentations), ( + ckpt_axis_fragmentation, + sh_ten, + ) + for dim, (app_chunk_dim_offset, ckpt_fragm, app_fragm) in enumerate( + zip( + sh_ten.local_chunk_offset_in_global(), + ckpt_axis_fragmentation, + sh_ten.axis_fragmentations, + ) + ): + # without `int`, it's an exact offset of the app shard expressed in ckpt_local_shape units + first_overlap_dim_offset = int(ckpt_fragm / app_fragm * app_chunk_dim_offset) + # `math.ceil` argument is an exact offset of the app next shard expressed + # in ckpt_local_shape units + next_overlap_dim_offset = math.ceil(ckpt_fragm / app_fragm * (app_chunk_dim_offset + 1)) + overlap_dim_offsets.append(range(first_overlap_dim_offset, next_overlap_dim_offset)) + + logger.debug( + f'Generated the following number of overlap shards for each dimension: ' + f'{list(map(len, overlap_dim_offsets))} for fragmentation ckpt ' + f'{ckpt_axis_fragmentation} vs app {sh_ten.axis_fragmentations} ' + f'and chunk offset {sh_ten.local_chunk_offset_in_global()}' + ) + reformulated_sh_tens = {} + for chunk_offset in product(*overlap_dim_offsets): + global_offset = tuple( + chunk_off * chunk_shape + for chunk_off, chunk_shape in zip(chunk_offset, ckpt_local_shape_with_prepended_axis) + ) + reformulated_sh_tens[(global_offset, ckpt_local_shape)] = ShardedTensor( + sh_ten.key, + None, + sh_ten.dtype, + ckpt_local_shape, + rmd.ckpt_orig_global_shape, + global_offset, + ckpt_axis_fragmentation, + sh_ten.replica_id, + sh_ten.prepend_axis_num, + sh_ten.allow_shape_mismatch, + flattened_range=slice(0, rmd.ckpt_reform_global_shape[-1]), # whole ckpt shard + ) + + # Now, we have to define the transformations from application sharding + # to checkpoint sharding. 
+ + @torch.no_grad() + def sh_ten_build_fn(*args, **kwargs): + # Here we simply return the precomputed tensors. + return reformulated_sh_tens + + @torch.no_grad() + def sh_ten_merge_fn(sub_state_dict): + # This is the non-flattened local tensor with original formulation + # that we are going to fill with shards loaded from the checkpoint. + app_non_flat_ten = torch.empty( + sh_ten.local_shape, + dtype=sh_ten.dtype, + device=sh_ten.data.device if sh_ten.data is not None else None, + ) + + assert len(sub_state_dict) > 0 + for (ckpt_global_offset, ckpt_local_shape), ckpt_ten in sub_state_dict.items(): + # For each ckpt shard, we fill the appropriate application shard part + dest_ten = app_non_flat_ten + src_ten = ckpt_ten.view(ckpt_local_shape) + # We don't need narrowing over `prepend_axis_num` axes so we take + # the [sh_ten.prepend_axis_num:] offsets slice + for ( + dim, + offset_for_saved_tensor, + offset_for_current_tensor, + length, + ) in _shards_get_overlap_region_wrt_saved_tensor( + saved_shard=ChunkStorageMetadata( + ckpt_global_offset[sh_ten.prepend_axis_num :], ckpt_local_shape + ), + current_shard=ChunkStorageMetadata( + sh_ten.global_offset[sh_ten.prepend_axis_num :], sh_ten.local_shape + ), + ): + src_ten = src_ten.narrow(dim, offset_for_saved_tensor, length) + dest_ten = dest_ten.narrow(dim, offset_for_current_tensor, length) + dest_ten.copy_(src_ten) + return app_non_flat_ten.flatten()[sh_ten.flattened_range] + + return ShardedTensorFactory( + sh_ten.key, + sh_ten.data, + sh_ten_build_fn, + sh_ten_merge_fn, + sh_ten.replica_id, + sh_ten.flattened_range, + ) diff --git a/megatron/core/dist_checkpointing/strategies/state_dict_saver.py b/megatron/core/dist_checkpointing/strategies/state_dict_saver.py new file mode 100644 index 0000000000..65c394b9ba --- /dev/null +++ b/megatron/core/dist_checkpointing/strategies/state_dict_saver.py @@ -0,0 +1,247 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +""" State dict saver for PyT Distributed format allowing asynchronous save. 
""" + +from logging import getLogger +from time import time +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +from torch.distributed.checkpoint import CheckpointException +from torch.distributed.checkpoint.default_planner import DefaultSavePlanner +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, Metadata +from torch.distributed.checkpoint.planner import SavePlan, SavePlanner +from torch.distributed.checkpoint.utils import _DistWrapper, _get_failure_dict + +if TYPE_CHECKING: + from .filesystem_async import FileSystemWriterAsync + from .torch import MCoreSavePlanner + + +logger = getLogger(__name__) + +from dataclasses import fields + + +def _compare_dataclasses(obj1, obj2): + if type(obj1) != type(obj2): + return f"Objects are of different types: {type(obj1)} and {type(obj2)}" + + differences = [] + for field in fields(obj1): + value1 = getattr(obj1, field.name) + value2 = getattr(obj2, field.name) + if value1 != value2: + differences.append(f"{field.name}: {value1} != {value2}") + + return differences if differences else "All fields are equal" + + +def save_state_dict_async_plan( + state_dict: STATE_DICT_TYPE, + storage_writer: 'FileSystemWriterAsync', + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + planner: Optional[Union[SavePlanner, 'MCoreSavePlanner']] = None, + cached_ckpt_structure: Optional[Tuple[SavePlan, SavePlan, bool]] = None, + loaded_all_plans: Optional[List[SavePlan]] = None, +) -> Tuple[Tuple['FileSystemWriterAsync', Union[Metadata, None], _DistWrapper], SavePlan, bool]: + """ + First stage of saving a state dict to storage. + + This is an async adjustment of torch.distributed.checkpoint.state_dict_saver. + In order to support async save, saving should be split into three parts: + 1. Planning + 2. Actual saving + 3. Finalization + + Out of these, step (2) *must* happen asynchronously. + The first step is realized with this function. + + The planning part consists of several steps, described here: + https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.SavePlanner + + Args: + state_dict (STATE_DICT_TYPE): state dict to save + storage_writer (FileSystemWriterAsync): in current version only an instance of + FileSystemWriterAsync + process_group (dist.ProcessGroup, optional): process group used for save planning + coordinator_rank (int, optional): coordinator rank for planning. Defaults to 0. + planner (SavePlanner, optional): save planner for torch.distributed.checkpoint format + cached_ckpt_structure (Tuple[SavePlan, SavePlan, bool], Optional): + Each object of this tuple will be used in the order as following + cached_central_plan (SavePlan): a globally coordinated save plan + cached in the previous iteration + cached_local_plan (SavePlan): a local plan + cached in the previous iteration + validated_cache_reuse (bool): boolean value to tell global_metadata and planning dict + is consistent over iterations + + Returns: Tuple of: + - storage writer (the one passed as input) + - metadata from planning (or None if we reuse cached global metadata) + - distributed wrapper used for planning + The return value of this function should be passed as an input to + `save_state_dict_async_finalize` and cached_plan to skip `reduce_scatter` at planning. 
+ """ + cached_central_plan, cached_local_plan, validated_cache_reuse = (None, None, False) + if cached_ckpt_structure: + cached_central_plan, cached_local_plan, validated_cache_reuse = cached_ckpt_structure + + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + dist_wrapper = _DistWrapper(process_group, True, coordinator_rank) + if planner is None: + planner = DefaultSavePlanner() + assert planner is not None + + global_metadata = None + logger.debug(f"rank: {rank}, starting state dict save") + local_plan = cached_local_plan + global_md_verify_reuse = False + + def local_step(): + nonlocal local_plan + assert planner is not None + # PyTorch 2.4 introduced additional `metadata` argument, + # we have to reference `is_coordinator` args by name + planner.set_up_planner(state_dict, is_coordinator=dist_wrapper.is_coordinator) + storage_writer.set_up_storage_writer(dist_wrapper.is_coordinator) + if not validated_cache_reuse and local_plan is None: + local_plan = planner.create_local_plan() + local_plan = storage_writer.prepare_local_plan(local_plan) + return local_plan + + def global_step(all_local_plans): + nonlocal global_metadata + assert planner is not None + all_local_plans, global_metadata = planner.create_global_plan(all_local_plans) + all_local_plans = storage_writer.prepare_global_plan(all_local_plans) + return all_local_plans + + # Execute local and global planning + # Ideally we want to use the cached plan. Otherwise if the planner and storage_writer + # allow it (`can_run_decentralized_global_plan`) we gather the plans to create + # the metadata but prepare the plans independently on each rank. + # In the worst case we have to reduce_scatter all the plans. + start_plan = time() + if validated_cache_reuse and cached_central_plan: + logger.debug(f"rank: {rank}, Passed cache reusable") + local_step() + central_plan = cached_central_plan + elif getattr(planner, 'can_run_decentralized_global_plan', False) and getattr( + storage_writer, 'can_run_decentralized_global_plan', False + ): + local_plan = local_step() + global_md_verify_reuse = verify_global_md_reuse( + loaded_all_plans, local_plan, rank, dist_wrapper + ) + + if not loaded_all_plans or not global_md_verify_reuse: + all_local_plans = dist_wrapper.gather_object(local_plan) + if dist_wrapper.is_coordinator: + _, global_metadata = planner.create_global_plan(all_local_plans) + global_metadata.all_local_plans = all_local_plans + else: + logger.debug(f"rank: {rank}, Passed cached global metadata") + global_metadata = None + local_plan = planner.create_decentralized_global_plan(local_plan) + local_plan = storage_writer.prepare_decentralized_global_plan(local_plan) + central_plan = local_plan + else: + central_plan = dist_wrapper.reduce_scatter("plan", local_step, global_step) + central_plan = planner.finish_plan(central_plan) + end_plan = time() + logger.debug(f"rank: {rank}, plan time: {end_plan - start_plan}") + # Prepare async writing of tensors. 
+ # The `storage_writer` will store the information about tensors it needs to save + start = time() + storage_writer.prepare_write_data(central_plan, planner) + end = time() + logger.debug(f"{time()} rank: {rank}, write(async) time: {end - start}") + return ( + (storage_writer, global_metadata, dist_wrapper), + central_plan, + local_plan, + cached_central_plan == central_plan, + global_md_verify_reuse, + ) + + +def verify_global_md_reuse( + loaded_all_plans: List[SavePlan], local_plan: SavePlan, rank: int, dist_wrapper: _DistWrapper +) -> bool: + """ + Verifies that global metadata reuse is possible by checking the loaded plans from the + checkpoint are consistent, which means we have the same settings when resuming training. + Args: + loaded_all_plans: List[SavePlan], The loaded plans from the checkpoint + (stored in checkpoint metadata). + local_plan: SavePlan, The local save plan. + rank: Current process rank. + dist_wrapper (_DistWrapper): distributed wrapper created during planning + + Returns: True iff the global metadata reuse is possible. + + """ + logger.debug(f"verifying reuse of global metadata") + if not loaded_all_plans: + global_md_verify_reuse = False + logger.debug("loaded global metadata reuse verification: no loaded plans passed") + + elif len(loaded_all_plans) == dist_wrapper.get_world_size(): + local_verify_reuse = all( + getattr(local_plan, f.name) == getattr(loaded_all_plans[rank], f.name) + for f in fields(local_plan) + if f.name != 'storage_data' + ) + + if not local_verify_reuse: + logger.debug( + f"local_verify_reuse is False: diffs -" + f" {_compare_dataclasses(local_plan, loaded_all_plans[rank])}" + ) + all_results = torch.tensor([local_verify_reuse], dtype=torch.int, device='cuda') + torch.distributed.all_reduce(all_results, op=torch.distributed.ReduceOp.MIN) + # Check if all reduced results are True + global_md_verify_reuse = all_results.item() == 1 + else: + global_md_verify_reuse = False + return global_md_verify_reuse + + +def save_state_dict_async_finalize( + storage_writer: 'FileSystemWriterAsync', global_metadata: Metadata, dist_wrapper: _DistWrapper +) -> None: + """ + Finalization of save_state_dict_async_plan. + + The input arguments are the same as the save_state_dict_async_plan output, + the `write_results` are retrieved from the storage_writer. + + Args: + storage_writer (FileSystemWriterAsync): storage writer used for planning + global_metadata (Metadata): metadata created during planning + dist_wrapper (_DistWrapper): distributed wrapper created during planning + + Returns: None + """ + write_results = storage_writer.retrieve_write_results() + + # Gather the write results that will be saved to the metadata file. 
+ gather_start = time() + all_results = dist_wrapper.gather_object(write_results) + gather_end = time() + logger.debug(f"{gather_end}, {torch.distributed.get_rank()}, gather: {gather_end-gather_start}") + + # Store the metadata on coordinator rank + if dist_wrapper.is_coordinator: + node_failures = _get_failure_dict(all_results) + if len(node_failures) == 0: + assert global_metadata is not None + write_start = time() + storage_writer.finish(global_metadata, all_results) + write_end = time() + logger.debug(f"{write_end}, metadata_write: {write_end - write_start}") + else: + raise CheckpointException("write", node_failures) diff --git a/megatron/core/dist_checkpointing/strategies/tensorstore.py b/megatron/core/dist_checkpointing/strategies/tensorstore.py index 4a619353a1..0b20bf5e77 100644 --- a/megatron/core/dist_checkpointing/strategies/tensorstore.py +++ b/megatron/core/dist_checkpointing/strategies/tensorstore.py @@ -13,19 +13,22 @@ from ..core import CheckpointingException from ..dict_utils import dict_list_map_inplace from ..mapping import ShardedStateDict, ShardedTensor -from .base import LoadShardedStrategy, StrategyAction, default_strategies -from .zarr import ( - load_zarr_based_sharded_metadata, - numpy_to_torch_dtype_dict, - postprocess_numpy_array, -) - -_import_trigger = None +from .base import LoadShardedStrategy, StrategyAction, register_default_strategy +from .zarr import load_zarr_based_sharded_metadata, postprocess_numpy_array logger = getLogger(__name__) +def register_default_tensorstore_strategies(): + """Register default strategies leveraging tensorstore.""" + register_default_strategy( + StrategyAction.LOAD_SHARDED, 'zarr', 1, TensorStoreLoadShardedStrategy() + ) + + class TensorStoreLoadShardedStrategy(LoadShardedStrategy): + """Load strategy for Zarr backend using `tensorstore` for loading.""" + def __init__(self, load_directly_on_device: bool = False): super().__init__() self.load_directly_on_device = load_directly_on_device @@ -58,6 +61,8 @@ def check_version_compatibility(self, loaded_version): def merge_global_slice_with_shape(global_slice, actual_shape, key): + """Intersects the global slice with the actual shape (prevent overflow).""" + def _merge_slice(dim_slice, dim_size): if isinstance(dim_slice, slice): assert ( @@ -111,21 +116,13 @@ def _load_regular_chunk(sharded_tensor: ShardedTensor, checkpoint_dir: Path): def open_ts_array(arr_path: Path): """Opens a Zarr file array with Tensorstore with basic setting. - Arguments: + Args: arr_path (Path): path to a Zarr (Tensorstore) array """ spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}} - spec['kvstore'] = { - 'driver': 'file', - 'path': str(arr_path), - } + spec['kvstore'] = {'driver': 'file', 'path': str(arr_path)} try: arr = ts.open(ts.Spec(spec), open=True).result() except Exception as e: raise CheckpointingException(f'Array {arr_path} could not be loaded. Error: {e}') from e return arr - - -default_strategies[StrategyAction.LOAD_SHARDED.value][ - ('zarr', 1) -] = TensorStoreLoadShardedStrategy() diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py new file mode 100644 index 0000000000..8982014dce --- /dev/null +++ b/megatron/core/dist_checkpointing/strategies/torch.py @@ -0,0 +1,1010 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +""" Strategies using PyTorch distributed.checkpoint as an underlying format. 
""" +import io +import os +import pickle +import warnings +from collections import ChainMap, defaultdict +from dataclasses import dataclass +from itertools import product +from logging import getLogger +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast + +import torch +from packaging.version import Version as PkgVersion +from torch.distributed import checkpoint +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed._shard.sharded_tensor import Shard +from torch.distributed._shard.sharded_tensor import ShardedTensor as TorchShardedTensor +from torch.distributed._shard.sharded_tensor import ShardedTensorMetadata, TensorProperties +from torch.distributed.checkpoint import ( + BytesStorageMetadata, + DefaultLoadPlanner, + DefaultSavePlanner, + FileSystemReader, + FileSystemWriter, + LoadPlan, + Metadata, + ReadItem, + SavePlan, + TensorStorageMetadata, + WriteItem, +) +from torch.distributed.checkpoint._nested_dict import FLATTEN_MAPPING, unflatten_state_dict +from torch.distributed.checkpoint._traverse import OBJ_PATH, traverse_state_dict +from torch.distributed.checkpoint.metadata import Metadata +from torch.distributed.checkpoint.planner_helpers import _create_write_items + +from ...utils import get_torch_version, is_torch_min_version +from ..core import CheckpointingException +from ..dict_utils import nested_values +from ..mapping import ( + ShardedBase, + ShardedObject, + ShardedStateDict, + ShardedTensor, + StateDict, + is_main_replica, +) +from .async_utils import AsyncRequest +from .base import ( + AsyncSaveShardedStrategy, + LoadShardedStrategy, + StrategyAction, + register_default_strategy, +) +from .cached_metadata_filesystem_reader import CachedMetadataFileSystemReader +from .filesystem_async import FileSystemWriterAsync +from .resharding import ( + TensorReformulationMetadata, + apply_nd_flattened_tensors_reformulation, + is_nd_flattened_tensor, + nd_flattened_tensor_reformulated_global_shape, + restore_nd_flattened_tensors_formulation, +) +from .state_dict_saver import save_state_dict_async_finalize, save_state_dict_async_plan + +try: + if not torch.cuda.is_available(): + raise ImportError + from transformer_engine.pytorch.float8_tensor import Float8Tensor + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +try: + from torch.distributed._tensor import DTensor + + HAVE_DTENSOR = True +except ImportError: + HAVE_DTENSOR = False + +_metadata_fn: str = ".metadata" + + +def register_default_torch_strategies(): + """Register default strategies related to PyT Distributed backend.""" + register_default_strategy( + StrategyAction.LOAD_SHARDED, 'torch_dist', 1, TorchDistLoadShardedStrategy() + ) + register_default_strategy( + StrategyAction.SAVE_SHARDED, 'torch_dist', 1, TorchDistSaveShardedStrategy('torch_dist', 1) + ) + + +logger = getLogger(__name__) + + +def flatten_state_dict( + state_dict: ShardedStateDict, +) -> Tuple[ShardedStateDict, Dict[str, OBJ_PATH]]: + """Flattens state dict into a single level dict. 
+ + It's a copy of torch.distributed.checkpoint._nested_dict.flatten_state_dict + which also accepts ShardedBase tensors as terminal objects + + Args: + state_dict (ShardedStateDict): state dict to be flattened + + Returns (tuple): flattened state dict and a mapping allowing to recreate the original one + + """ + flattened = {} + mappings = {} + + def flat_copy(path: OBJ_PATH, value: Any) -> None: + new_fqn = ".".join(map(str, path)) + if new_fqn in flattened: + raise ValueError(f"duplicated flatten key {new_fqn}") + flattened[new_fqn] = value + mappings[new_fqn] = path + + traverse_state_dict(state_dict, flat_copy, lambda x: isinstance(x, (torch.Tensor, ShardedBase))) + return flattened, mappings + + +def sharded_tensor_to_torch_sharded_tensor( + sh_tens: List[ShardedTensor], + rank: Optional[int] = None, + load_legacy_1d_flatten_tensors: bool = False, +) -> TorchShardedTensor: + """Convert MCore ShardedTensor to PyT ShardedTensor. PyT requires information about all chunks. + + On high-level, this function follows the logic of + torch.distributed.fsdp._shard_utils._create_chunk_sharded_tensor. + Additionally, it saves `prepend_axis_num` and `has_flattened_range` (specific to MCore) + as attributes for further restoration in `_unwrap_pyt_sharded_tensor`. + + NOTE: this function assumes regular (grid) sharding of the MCore ShardedTensor. + The only local irregularities could be introduced with a `flattened_range` attribute. + + This function handles 2 different type of ShardedTensors: + 1. Non-flat regular ShardedTensors (`not has_flattened_range`) + 2. N-D flattened ShardedTensors (`has_flattened_range`) + + (1) type are saved according to their original shape. + Type (2) however requires global shape adjustment for efficiency: + we treat [X, Y, Z] global shape tensor with local shape [x, y, z] + as a [X // x, Y // y, Z // z, x * y * z] tensor with last axis + partitioned according to `flattened_range` slices. + This will need special handling while resharding. + + Args: + sh_tens (List[ShardedTensor]): list of sharded tensors to convert + rank (int, optional): current process rank passed to PyT ShardedTensor. + If None, assumes rank in the default pg. + load_legacy_1d_flatten_tensors (bool, optional): flag indicating if 1-D flattened tensors + should be loaded in a legacy way. Defaults to False. + + Returns (TorchShardedTensor): PyT ShardedTensor containing all passed shards. 
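[Editor's note] Since the conversion assumes a regular grid of shards, the global offsets it enumerates are just the product of per-axis chunk indices scaled by the chunk shape; a toy standalone sketch (values assumed):

from itertools import product

axis_fragmentations = (2, 3)   # number of chunks per axis
chunk_shape = (4, 5)           # shape of a single chunk

for fragment_offsets in product(*map(range, axis_fragmentations)):
    offset = tuple(f * s for f, s in zip(fragment_offsets, chunk_shape))
    print(fragment_offsets, "->", offset)
# (0, 0) -> (0, 0), (0, 1) -> (0, 5), ..., (1, 2) -> (4, 10)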
+ + """ + if rank is None: + rank = torch.distributed.get_rank() + + some_sh_ten = sh_tens[0] + has_flattened_range = some_sh_ten.flattened_range is not None + + for sh_ten in sh_tens: + assert (sh_ten.flattened_range is not None) == has_flattened_range, sh_tens + if not sh_ten.data.is_contiguous(): + sh_ten.data = sh_ten.data.contiguous() + + if load_legacy_1d_flatten_tensors and len(some_sh_ten.global_shape) == 1: + # Legacy 1-D flattened tensors are loaded as non-flat regular ShardedTensors + has_flattened_range = False + + local_global_offsets = {} + + prepend_axis_num = sh_tens[0].prepend_axis_num + # Determine local shards according to tensor type (see docs) + if has_flattened_range: + # Type (3) case: N-D flattened ShardedTensors + for sh_ten in sh_tens: + local_global_offsets.setdefault(sh_ten.local_chunk_offset_in_global(), []).append( + sh_ten + ) + assert sh_ten.data.ndim == 1, sh_ten + sh_ten.data = sh_ten.data.view((1,) * len(sh_ten.global_shape) + (-1,)) + + # Global shape reformulation: + global_shape = nd_flattened_tensor_reformulated_global_shape(some_sh_ten) + offsets_shape = (1,) * len( + some_sh_ten.global_shape + ) # reformulated global shape has shape equal ti number of local chunks + + local_shards = [ + Shard.from_tensor_and_offsets( + sh_ten.data, + list( + sh_ten.local_chunk_offset_in_global() + (sh_ten.flattened_range.start,) + ), # additional flattened offset + rank, + ) + for sh_ten in sh_tens + ] + else: + # Type (1) case: non-flat regular ShardedTensors + for sh_ten in sh_tens: + local_global_offsets.setdefault(sh_ten.global_offset, []).append(sh_ten) + sh_ten.data = sh_ten.data.view( + (1,) * prepend_axis_num + sh_ten.local_shape + ) # adjust to prepended_axis_num + + global_shape = some_sh_ten.global_shape + offsets_shape = some_sh_ten.data.shape # includes prepended axes + + local_shards = [ + Shard.from_tensor_and_offsets( + sh_ten.data, list(sh_ten.global_offset), rank # simple case + ) + for sh_ten in sh_tens + ] + + # Create a ShardedTensor without invoking communication. Determine global shards + world_size = torch.distributed.get_world_size() + shard_metadata = [] + # NOTE: here we assume a regular grid of shards + for fragment_offsets in product(*map(range, some_sh_ten.axis_fragmentations)): + offset = tuple(map(lambda x: x[0] * x[1], zip(fragment_offsets, offsets_shape))) + if offset in local_global_offsets: + # local shard + placement = f"rank:{rank}/cuda" + for sh_ten in local_global_offsets[offset]: + if has_flattened_range: + assert offset == sh_ten.local_chunk_offset_in_global() + # This is not an actual offset, but an offset of the whole shard + # This is needed for a PyT Dist internal integrity check + offset = sh_ten.local_chunk_offset_in_global() + (0,) + size = (1,) * len(offsets_shape) + global_shape[-1:] + else: + size = sh_ten.data.shape + shard_metadata.append(ShardMetadata(offset, size, placement)) + + else: + # pylint: disable=line-too-long + # for shards from other ranks we provide simplistic data - this information will be discarded + # during TorchShardedTensor._init_from_local_shards_and_global_metadata call. + # Due to a bug in PyT 24.05 container we must specify some concrete rank within a world size. + # The exact rank doesn't matter as long as it's different than my rank - hence (rank + 1) % WS. 
+ placement = f"rank:{(rank + 1) % world_size}/cuda" + if has_flattened_range: + offset = offset + (0,) + size = (1,) * len(offsets_shape) + global_shape[-1:] + else: + size = offsets_shape + shard_metadata.append(ShardMetadata(offset, size, placement)) + + tensor = some_sh_ten.data + sharded_tensor_metadata = ShardedTensorMetadata( + shards_metadata=shard_metadata, + size=torch.Size(global_shape), + tensor_properties=TensorProperties( + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + memory_format=torch.contiguous_format, + pin_memory=tensor.is_pinned(), + ), + ) + pyt_sh_ten = TorchShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=None + ) + # Store MCore related data as PyTShardedTensor attribute. + # This won't be stored in the checkpoint, only for runtime purposes + pyt_sh_ten.mcore_sh_ten = sh_ten.without_data() + pyt_sh_ten.mcore_metadata = {} + if has_flattened_range: + pyt_sh_ten.mcore_metadata['nd_reformulated_orig_global_shape'] = sh_ten.global_shape + return pyt_sh_ten + + +def mcore_to_pyt_state_dict( + state_dict: Dict[str, List[ShardedBase]], + is_loading: bool = False, + init_device: torch.device = torch.device("cpu"), + load_legacy_1d_flatten_tensors: bool = False, +) -> Dict[str, Union[TorchShardedTensor, io.BytesIO]]: + """Convert state dict with ShardedTensors and ShardedObjects + to state dict compatible with PyT Dist format. + + Operates in-place and returns the original state dict. + + Args: + state_dict (Dict[str, List[ShardedBase]]): flattened state dict, where values + are lists of either ShardedTensor or ShardedObjects. + is_loading (bool, optional): flag indicating if loading or saving. Defaults to False. + init_device (torch.device, optional): device to initialize potentially missing tensors + during loading. Defaults to 'cpu'. + + Returns (Dict[str, Union[TorchShardedTensor, io.BytesIO]]): original dictionary with values + converted either into PyT ShardedTensors or io.BytesIO. + + """ + rank = torch.distributed.get_rank() + pyt_state_dict = {} + + def _mcore_to_torch_sharded_tensor(sh_tens: List[ShardedTensor]) -> TorchShardedTensor: + """Build a PyT ShardedTensor from given shards. 
+ + During loading: + - if data is None, initialize it with an empty tensor (will be used to copy the data into) + - if `allow_shape_mismatch` is True, the data is initialized with zeros + prior to loading (not all parts of the tensor will be read from the checkpoint) + """ + assert all(isinstance(sh_ten, ShardedTensor) for sh_ten in sh_tens), sh_tens + for sh_ten in sh_tens: + if sh_ten.data is None: + if is_loading: + sh_ten.init_data( + init_device, + init_fn=torch.zeros if sh_ten.allow_shape_mismatch else torch.empty, + ) + else: + raise CheckpointingException(f'`data` attr is None for {sh_ten}') + else: + sh_ten.data = sh_ten.data.detach() + if sh_ten.allow_shape_mismatch and is_loading: + sh_ten.data.zero_() + + torch_sh_ten = sharded_tensor_to_torch_sharded_tensor( + sh_tens, rank, load_legacy_1d_flatten_tensors + ) + torch_sh_ten.key = sh_tens[0].key + return torch_sh_ten + + def _mcore_to_torch_sharded_object(sh_objs: List[ShardedObject]) -> io.BytesIO: + """Build io.BytesIO from given sharded objects data.""" + assert all(isinstance(sh_obj, ShardedObject) for sh_obj in sh_objs), sh_objs + serialized_data = io.BytesIO() + torch.save([sh_obj.data for sh_obj in sh_objs], serialized_data) + return serialized_data + + for k, v in state_dict.items(): + if isinstance(v[0], ShardedTensor): + v = cast(List[ShardedTensor], v) + pyt_state_dict[k] = _mcore_to_torch_sharded_tensor(v) + else: + v = cast(List[ShardedObject], v) + pyt_state_dict[k] = _mcore_to_torch_sharded_object(v) + + return pyt_state_dict + + +def _unwrap_pyt_sharded_tensor(sh_ten: TorchShardedTensor) -> List[torch.Tensor]: + """Unwrap tensor from PyT ShardedTensor instance. + + If `prepend_axis_num` was non-zero (which is specific to MCore ShardedTensor) + then the tensor has additional singleton dimensions which should be squeezed. 
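[Editor's note] A tiny runnable sketch (assumed shapes) of the squeeze performed below: prepended singleton axes, added only for sharding bookkeeping, are removed before handing the tensor back to the application:

import torch

prepend_axis_num = 2
ten = torch.arange(6).view(1, 1, 2, 3)   # two prepended singleton axes
for _ in range(prepend_axis_num):
    ten = ten.squeeze(0)
print(ten.shape)                         # torch.Size([2, 3])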
+ """ + mcore_sh_ten = sh_ten.mcore_sh_ten + ret_tensors = [] + for sh in sh_ten.local_shards(): + ten = sh.tensor + if mcore_sh_ten.flattened_range is not None: + assert ten.shape[:-1] == (1,) * (len(ten.shape) - 1), ten.shape + ten = ten.view(-1) + else: + for _ in range(mcore_sh_ten.prepend_axis_num): + ten = ten.squeeze(0) + ret_tensors.append(ten) + return ret_tensors + + +def _replace_state_dict_keys_with_sharded_keys( + sharded_state_dict: ShardedStateDict, keep_only_main_replica: bool = False +) -> Tuple[Dict[str, List[ShardedBase]], FLATTEN_MAPPING, Dict[str, List[str]]]: + """Group ShardedBase objects by keys and + return mappings required for recreating the original dict.""" + flat_sd, flat_mapping = flatten_state_dict(sharded_state_dict) + rename_mapping = defaultdict(list) + new_flat_sd = defaultdict(list) + for k, sh_base in flat_sd.items(): + assert isinstance(sh_base, ShardedBase), type(sh_base) + key = sh_base.unique_key if isinstance(sh_base, ShardedObject) else sh_base.key + if is_main_replica(sh_base.replica_id) or not keep_only_main_replica: + rename_mapping[key].append(k) + new_flat_sd[key].append(sh_base) + return new_flat_sd, flat_mapping, rename_mapping + + +def _replace_sharded_keys_with_state_dict_keys( + state_dict: Dict[str, List[Union[torch.Tensor, io.BytesIO]]], + flat_mapping: FLATTEN_MAPPING, + rename_mapping: Dict[str, List[str]], +): + """Inverse of _replace_state_dict_keys_with_sharded_keys.""" + recovered_sd = {} + for k, tensors in state_dict.items(): + assert len(tensors) == len(rename_mapping[k]) + for ten, recovered_k in zip(tensors, rename_mapping[k]): + recovered_sd[recovered_k] = ten + + return unflatten_state_dict(recovered_sd, flat_mapping) + + +def _restore_dict_types(x: Union[dict, list, Any], keys_template: Union[dict, list, Any]): + """Recursively update `x` keys, based on `keys_template`.""" + if isinstance(keys_template, dict): + assert isinstance(x, dict), type(x) + for k, v in keys_template.items(): + if not isinstance(k, str): + assert str(k) in x, (k, x.keys) + x[k] = x.pop(str(k)) + _restore_dict_types(x[k], v) + elif isinstance(keys_template, list): + assert isinstance(x, list), type(x) + for x_val, templ_val in zip(x, keys_template): + _restore_dict_types(x_val, templ_val) + + +@dataclass(frozen=True) +class MCoreSavePlan(SavePlan): + """SavePlan with MCore specific data.""" + + mcore_data: Dict[str, Dict[str, Any]] = None # Mcore related data about each tensor + + +class MCoreSavePlanner(DefaultSavePlanner): + """Differs with the default planner by saving BytesIO objects on all ranks. + + In the integration of MCore with PyT Distributed format, BytesIO objects + come from ShardedObjects, which should be treated as separate objects on each rank + (not common on all ranks). + + Also, the objects are already packed in io.BytesIO, so no need to redo it + in transform_object. + """ + + def __init__( + self, + *args, + dedup_replicated_tensors: Optional[bool] = None, + nd_flattened_global_shapes: Optional[Dict[str, Tuple[int, ...]]] = None, + can_run_decentralized_global_plan: bool = True, + **kwargs, + ) -> None: + # `dedup_replicated_tensors` was deprecated in 2.3; this check avoids warnings + # during saving. 
+ if get_torch_version() <= PkgVersion("2.2"): + kwargs['dedup_replicated_tensors'] = dedup_replicated_tensors + super().__init__(*args, **kwargs) + self.nd_flattened_global_shapes = nd_flattened_global_shapes or {} + self.can_run_decentralized_global_plan = can_run_decentralized_global_plan + if can_run_decentralized_global_plan: + assert ( + not dedup_replicated_tensors + ), 'Cannot run decentralized plan with dedup_replicated_tensors=True' + assert ( + not self.flatten_state_dict + ), 'Cannot run decentralized plan with flatten_state_dict=True' + + def create_local_plan(self) -> SavePlan: + """Adds IOBytes write request on non-coordinator ranks.""" + + # NOTE: for PyT 2.4.0a0 we can't rely on `create_default_local_save_plan` because + # some alpha versions (specifically 2.4.0a0+f70bd71a48 in 24.06 NGC PyTorch container) + # add iobytes request only on coordinator ranks and some alpha versions + # (specifically 2.4.0a0+3bcc3cddb5 in 24.07 NGC PyTorch container) + # add those requests on all ranks. We inline a simplified version of this method below. + write_items = [] + for fqn, obj in self.state_dict.items(): + assert not HAVE_DTENSOR or not isinstance( + obj, DTensor + ) # translation from MCore ShardedTensors shouldn't result in DTensors + # Create write requests for tensor and bytes values. + # For MCore, these should be already non-duplicates. + write_items += _create_write_items(fqn, obj) + + self.plan = MCoreSavePlan( + items=write_items, + planner_data=self.mappings, + mcore_data={ + k: sh_ten.mcore_metadata + for k, sh_ten in self.state_dict.items() + if isinstance(sh_ten, TorchShardedTensor) + }, + ) + return self.plan + + def create_global_plan(self, all_plans: List[MCoreSavePlan]) -> Tuple[List[SavePlan], Metadata]: + """Merges MCore data for all plans.""" + global_plan, metadata = super().create_global_plan(all_plans) + metadata.mcore_data = dict(ChainMap(*(plan.mcore_data for plan in all_plans))) + return global_plan, metadata + + def create_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan: + """Nothing to do, just some checks. + + Args: + local_plan (SavePlan): local plan to turn to a global plan + (without interactions with other ranks) + + Returns: + SavePlan - locally transformed plan equivalent to the plan that would be + created by the coordinator + """ + assert ( + not self.flatten_state_dict + ), 'Cannot run decentralized plan with flatten_state_dict=True' + assert not local_plan.planner_data, 'Planner data should be empty with decentralized plan' + return local_plan + + def transform_object(self, write_item: WriteItem, object: Any): + """Make no transformations - bytes objects are already serialized.""" + return object + + +class MCoreLoadPlanner(DefaultLoadPlanner): + """Adds global shape validation to the default planner. + + If global shape validation can be ignored (shouldn't!), the default + load planner can be used. 
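[Editor's note] The global-shape validation performed by this planner reduces to comparing the shape recorded in the checkpoint metadata against what the application requests; a toy standalone sketch (values assumed):

checkpoint_shapes = {"decoder.weight": (1024, 512)}   # as recorded in ckpt metadata
requested_shapes = {"decoder.weight": (1024, 256)}    # as requested by the application

for key, expected in requested_shapes.items():
    loaded = checkpoint_shapes.get(key)
    if loaded != expected:
        print(f"Global shape mismatch for {key}: loaded {loaded}, expected {expected}")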
+ """ + + def __init__( + self, *args, shapes_validation_sharded_tensors: Iterable[ShardedTensor] = (), **kwargs + ) -> None: + super().__init__(*args, **kwargs) + self.shapes_validation_sharded_tensors = shapes_validation_sharded_tensors + self._intermediate_read_item_and_target: Optional[Tuple[ReadItem, torch.Tensor]] = None + + def _validate_global_shapes(self, metadata, sharded_tensors): + for sh_ten in sharded_tensors: + if sh_ten.key not in metadata.state_dict_metadata: + raise KeyError( + f"{sh_ten.key} from model not in state dict:" + f" {sorted(metadata.state_dict_metadata.keys())}" + ) + loaded_shape = metadata.state_dict_metadata[sh_ten.key].size + if not is_nd_flattened_tensor(sh_ten): + expected_shape = sh_ten.global_shape + else: + expected_shape = nd_flattened_tensor_reformulated_global_shape(sh_ten) + if loaded_shape != expected_shape: + if is_nd_flattened_tensor(sh_ten) and len(sh_ten.global_shape) == 1: + # Handle legacy 1-D flattened tensors checkpoint format + # where the global shape is not stored in the metadata + expected_shape = sh_ten.global_shape + if loaded_shape == expected_shape: + continue + _msg = ( + f'Global shape mismatch for loaded ({loaded_shape})' + f' and expected ({expected_shape}) tensor' + f' for key {sh_ten.key}' + ) + raise CheckpointingException(_msg) + + def create_local_plan(self) -> LoadPlan: + """Runs additional shapes validation.""" + self._validate_global_shapes(self.metadata, self.shapes_validation_sharded_tensors) + return super().create_local_plan() + + def resolve_tensor(self, read_item: ReadItem): + """Override to add FP8 support. + + Narrowing the Float8Tensor can create incontiguous tensors and there are + no `copy` kernels for such cases. This method creates a contiguous FP8 + tensors so that the subsequent `copy_` in FileSystemReader succeeds. + Note that this requires tracking the original tensor + (as `self._intermediate_read_item_and_target` attribute) + and restoring it in `commit_tensor` method. + """ + target_tensor = super().resolve_tensor(read_item) + if ( + not target_tensor.is_contiguous() + and HAVE_TE + and isinstance(target_tensor, Float8Tensor) + ): + self._intermediate_read_item_and_target = (read_item, target_tensor) + target_tensor = Float8Tensor.make_like( + target_tensor, data=target_tensor._data.contiguous() + ) + return target_tensor + + def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None: + """Restores the original FP8 tensor saved in `resolve_tensor`.""" + if self._intermediate_read_item_and_target is not None: + interm_read_item, target_tensor = self._intermediate_read_item_and_target + assert ( + interm_read_item is read_item + ), '`commit_tensor` method should be called right after `resolve_tensor`' + target_tensor.copy_(tensor) + tensor = target_tensor + self._intermediate_read_item_and_target = None + return super().commit_tensor(read_item, tensor) + + +class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy): + """Async save strategy for the PyT Distributed format. + + The idea is to translate MCore ShardedTensors into PyT ShardedTensors + and use the async-adjusted torch.distributed.checkpoint saving mechanism + provided by the FileSystemWriterAsync writer. 
+ """ + + def __init__( + self, + backend: str, + version: int, + keep_only_main_replica: bool = True, + thread_count: int = 2, + cached_metadata: bool = False, + separation_hint: str = None, + ): + """Adds parameters specific to PyT Distributed format + Args: + backend (str): format backend string + version (int): format version + keep_only_main_replica (bool, optional): PyT Distributed has a mechanism + for deduplication, but replica_id aware deduplication is more coherent. + Default is True (recommended to keep it). + thread_count (int, optional): threads to use during saving. + Affects the number of files in the checkpoint (saving ranks * num_threads). + cached_metadata (bool, optional): Enables using cached global metadata to avoid + gathering local metadata every checkpointing invocation + separation_hint(str, optional): If provided, all tensors whose keys have this + prefix will be saved to a separate file. + """ + super().__init__(backend, version) + self.keep_only_main_replica = keep_only_main_replica + self.thread_count = thread_count + + # Cached SavePlans to skip plan in `save_state_dict_async_plan` + # cached outcome of `SavePlan.prepare_global_plan`, + # which aggregates local plans from all ranks + self.cached_central_plan: SavePlan = None + # cached outcome of `SavePlan.prepare_local_plan` describes how local state_dict is written + self.cached_local_plan: SavePlan = None + # Cached global metadata, only `coordinator` for dist-ckpt holds + # if central plans are consistent over iters + self.cached_global_metadata: Metadata = None + # This variable records if the ckpt structures are consistent + # so the following checkpoint savings reuse `cached_global_metadata` + self.validated_cache_reuse: bool = False + # The knob to enable cached metadata communication in saving + self.use_cached_ckpt_structure: bool = cached_metadata + + self.separation_hint = separation_hint + + self.validated_loaded_metadata_reuse = False + + def async_save( + self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path + ) -> AsyncRequest: + """Translates MCore ShardedTensors to PyT ShardedTensors & saves in PyT Distributed format. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to save + checkpoint_dir (Path): checkpoint directory + + Returns: None + """ + # Translate the state dict + (sharded_state_dict, flat_mapping, rename_mapping) = ( + _replace_state_dict_keys_with_sharded_keys( + sharded_state_dict, self.keep_only_main_replica + ) + ) + pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, False) + # Use PyT saving mechanism + writer = FileSystemWriterAsync( + checkpoint_dir, separation_hint=self.separation_hint, thread_count=self.thread_count + ) + # This should be set differently if we run in a smaller process group than the default + coordinator = 0 + # Try twice to validate the generated `central_plan` is the same across iterations + # If so, reuse `cached_central_plan` and `cached_global_metadata` + # From the 3rd iteration, `save_state_dict_async_plan` will not generate `global_metadata` + # (return None) so `self.cached_global_metadata` is reused + args_cached_plans = None + loaded_all_plans = None + if self.use_cached_ckpt_structure: + loaded_all_plans = getattr(self.cached_global_metadata, "all_local_plans", None) + if loaded_all_plans is None: + logger.debug( + "no all_local_plans in metadata - can't verify global metadata reuse..." 
+ ) + + args_cached_plans = ( + self.cached_central_plan, + self.cached_local_plan, + self.validated_cache_reuse, + ) + + ( + save_state_dict_ret, + self.cached_central_plan, + self.cached_local_plan, + self.validated_cache_reuse, + self.validated_loaded_metadata_reuse, + ) = save_state_dict_async_plan( + pyt_state_dict, + writer, + None, + coordinator, + planner=MCoreSavePlanner( + dedup_replicated_tensors=not self.keep_only_main_replica, flatten_state_dict=False + ), + cached_ckpt_structure=args_cached_plans, + loaded_all_plans=loaded_all_plans, + ) + rank = torch.distributed.get_rank() + if self.use_cached_ckpt_structure: + if ( + loaded_all_plans + and self.cached_global_metadata + and self.validated_loaded_metadata_reuse + ): + if coordinator == rank: + logger.debug( + f"rank: {rank}, reuse global metadata from loaded" + f" .metadata, {save_state_dict_ret[1]}" + ) + save_state_dict_ret = list(save_state_dict_ret) + save_state_dict_ret[1] = self.cached_global_metadata + + elif self.validated_cache_reuse: + logger.debug(f"rank: {rank}, cache validated") + if save_state_dict_ret[1]: # when global_metadata is not cached + self.cached_global_metadata = save_state_dict_ret[1] # Cache Metadata + # Only Coordinator rank holds cached global_metadata + # (None is returned for global_metadata) + elif coordinator == rank: + logger.debug( + f"rank: {rank}, reuse global metadata cached from previous" + f" save iteration, {save_state_dict_ret[1]}" + ) + save_state_dict_ret = list(save_state_dict_ret) + save_state_dict_ret[1] = self.cached_global_metadata + + return self._get_save_and_finalize_callbacks(writer, save_state_dict_ret) + + def _get_save_and_finalize_callbacks(self, writer, save_state_dict_ret) -> AsyncRequest: + save_fn_args = writer.get_save_function_and_args() + save_fn, preload_fn, save_args = save_fn_args + + def finalize_fn(): + save_state_dict_async_finalize(*save_state_dict_ret) + torch.distributed.barrier() + + return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn) + + def can_handle_sharded_objects(self): + return True + + +def get_reformulation_metadata( + sharded_state_dict: ShardedStateDict, checkpoint_dir: Path +) -> Dict[str, TensorReformulationMetadata]: + """Reads MCore data for N-D flattened tensors from checkpoint metadata during ckpt load. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to load + checkpoint_dir (Path): checkpoint directory + + Returns: + Dict[str, TensorReformulationMetadata] - dictionary that maps keys of every + N-D flattened tensor from the sharded_state_dict to its original global shape + as stored in `mcore_data` in the checkpoint. + """ + ckpt_metadata = FileSystemReader(checkpoint_dir).read_metadata() + reformulation_metadata = {} + for sh_ten in nested_values(sharded_state_dict): + if not is_nd_flattened_tensor(sh_ten): + continue + try: + ckpt_global_shape = ckpt_metadata.mcore_data[sh_ten.key][ + 'nd_reformulated_orig_global_shape' + ] + except KeyError as e: + if len(sh_ten.global_shape) == 1: + warnings.warn( + f'Legacy checkpoint format detected for 1-D flattened tensor {sh_ten}. ' + 'Skip metadata reformulation.' 
+ ) + continue + raise CheckpointingException( + f'Cannot find global shape metadata for N-D flattened tensor {sh_ten} ' + f'in checkpoint metadata: {ckpt_metadata.mcore_data}' + ) from e + + reformulation_metadata[sh_ten.key] = TensorReformulationMetadata( + ckpt_global_shape, ckpt_metadata.state_dict_metadata[sh_ten.key].size + ) + return reformulation_metadata + + +class TorchDistLoadShardedStrategy(LoadShardedStrategy): + """Basic load strategy for the PyT Distributed format.""" + + def __init__(self): + self.cached_global_metadata: Optional[Metadata] = None + super().__init__() + + def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict: + """Translates MCore ShardedTensors to PyT ShardedTensors & loads from PyT Distributed fmt. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict with mapping + information to instruct loading + checkpoint_dir (Path): checkpoint directory + + Returns: loaded state dict + """ + # Apply N-D tensors resharding + reformulation_metadata = get_reformulation_metadata(sharded_state_dict, checkpoint_dir) + sharded_state_dict, formulation_restore_data = apply_nd_flattened_tensors_reformulation( + sharded_state_dict, reformulation_metadata + ) + + # Check if there are legacy 1-D flattened tensors in the checkpoint + has_legacy_1d_flattened_tensors = False + for sh_ten in nested_values(sharded_state_dict): + if is_nd_flattened_tensor(sh_ten) and sh_ten.key not in reformulation_metadata: + has_legacy_1d_flattened_tensors = True + break + + flexible_shape_sharded_tensors = [ + sh_ten + for sh_ten in nested_values(sharded_state_dict) + if isinstance(sh_ten, ShardedTensor) and not sh_ten.allow_shape_mismatch + ] + + orig_sharded_state_dict = sharded_state_dict + # MCore state dict to PyT Distributed compatible + (sharded_state_dict, flat_mapping, rename_mapping) = ( + _replace_state_dict_keys_with_sharded_keys(sharded_state_dict) + ) + pyt_state_dict = mcore_to_pyt_state_dict( + sharded_state_dict, True, load_legacy_1d_flatten_tensors=has_legacy_1d_flattened_tensors + ) + # Load PyT Distributed format + fsr = CachedMetadataFileSystemReader(checkpoint_dir) + checkpoint.load_state_dict( + pyt_state_dict, + fsr, + planner=MCoreLoadPlanner( + shapes_validation_sharded_tensors=flexible_shape_sharded_tensors + ), + ) + + self.cached_global_metadata = ( + fsr.read_metadata() + ) # no storage interaction thanks to caching + + pyt_state_dict = cast( + Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict + ) + # Unwrap ShardedTensors and return to original state dict + mcore_state_dict = { + k: v if not isinstance(v, TorchShardedTensor) else _unwrap_pyt_sharded_tensor(v) + for k, v in pyt_state_dict.items() + } + mcore_state_dict = _replace_sharded_keys_with_state_dict_keys( + mcore_state_dict, flat_mapping, rename_mapping + ) + _restore_dict_types(mcore_state_dict, orig_sharded_state_dict) + # Apply N-D tensors resharding postprocessing + mcore_state_dict = restore_nd_flattened_tensors_formulation( + mcore_state_dict, formulation_restore_data + ) + return mcore_state_dict + + def load_tensors_metadata(self, checkpoint_dir: Path, metadata: Metadata = None): + """Uses tensors metadata stored in the metadata file.""" + if metadata is None: + fs_reader = FileSystemReader(checkpoint_dir) + metadata = fs_reader.read_metadata() + + mcore_data = getattr(metadata, 'mcore_data', {}) + sharded_metadata = {} + for k, tp in metadata.state_dict_metadata.items(): + if not isinstance(tp, TensorStorageMetadata): + continue # 
load only tensors + + nd_orig_global_shape = mcore_data.get(k, {}).get('nd_reformulated_orig_global_shape') + if nd_orig_global_shape is None: + # Regular tensor + sharded_metadata[k] = ShardedTensor.from_rank_offsets( + k, torch.empty(tp.size, **tp.properties.__dict__, device='meta') + ).without_data() + else: + # N-D flattened tensor + unflat_ten = torch.empty( + nd_orig_global_shape, **tp.properties.__dict__, device='meta' + ) + flat_ten = unflat_ten.flatten() + sharded_metadata[k] = ShardedTensor.from_rank_offsets_flat( + k, + flat_ten, + unflat_ten.shape, + flattened_range=slice(0, unflat_ten.numel()), # whole slice + ).without_data() + + return sharded_metadata + + def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict: + """Uses tensors and objects metadata stored in the metadata file.""" + fs_reader = FileSystemReader(checkpoint_dir) + metadata = fs_reader.read_metadata() + + sharded_metadata = {} + for metadata_key, storage_metadata in metadata.state_dict_metadata.items(): + if not isinstance(storage_metadata, BytesStorageMetadata): + continue + sh_obj = ShardedObject.empty_from_unique_key(metadata_key) + sharded_metadata[sh_obj.unique_key] = sh_obj + + sharded_metadata.update(self.load_tensors_metadata(checkpoint_dir, metadata)) + return sharded_metadata + + def remove_sharded_tensors(self, checkpoint_dir: str, key_prefix: str): + """Removes checkpoint files whose keys have the given prefix. + + Performs the following steps: + 1. checks whether there are files that start with the key_prefix + 2. loads metadata + 3. removes all entries from the metadata that start with the key_prefix + 4. resaves the new metadata and removes the old metadata + 5. removes the relevant files + """ + + assert is_torch_min_version( + "2.3.0" + ), f'torch >= 2.3.0 is required for remove_sharded_tensors' + + distckpt_files = [f for f in os.listdir(checkpoint_dir) if f.endswith("distcp")] + files_to_remove = [f for f in distckpt_files if f.startswith(key_prefix)] + + if not files_to_remove: + warnings.warn( + f'There are no files in {checkpoint_dir} that begin with "{key_prefix}".' + f' Skipping removal.' 
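# Editorial aside (illustrative; the key prefix is hypothetical): this method can be used to
# prune a whole key namespace from an existing torch_dist checkpoint, e.g.
#
#     strategy.remove_sharded_tensors(ckpt_dir, key_prefix='optimizer.')
#
# The code below rewrites the .metadata file atomically (tmp file + rename, with a .bck
# fallback) and only then deletes the matching *.distcp files.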
+ ) + return + + fs_reader = FileSystemReader(checkpoint_dir) + original_metadata = fs_reader.read_metadata() + + new_state_dict_metadata = {} + new_planner_data = {} + new_storage_data = {} + for k in original_metadata.state_dict_metadata.keys(): + if k.startswith(key_prefix): + continue + new_state_dict_metadata[k] = original_metadata.state_dict_metadata[k] + for k in original_metadata.planner_data.keys(): + if k.startswith(key_prefix): + continue + new_planner_data[k] = original_metadata.planner_data[k] + for k in original_metadata.storage_data.keys(): + if k.fqn.startswith(key_prefix): + continue + new_storage_data[k] = original_metadata.storage_data[k] + metadata = Metadata( + state_dict_metadata=new_state_dict_metadata, + planner_data=new_planner_data, + storage_data=new_storage_data, + ) + fs_writer = FileSystemWriter(checkpoint_dir) + metadata_filename = cast(Path, fs_writer.fs.concat_path(fs_writer.path, _metadata_fn)) + tmp_path = cast( + metadata_filename, fs_writer.fs.concat_path(fs_writer.path, f"{_metadata_fn}.tmp") + ) + old_path = cast( + metadata_filename, fs_writer.fs.concat_path(fs_writer.path, f"{_metadata_fn}.bck") + ) + ## save the new metadata + with fs_writer.fs.create_stream(tmp_path, "wb") as metadata_file: + pickle.dump(metadata, metadata_file) + try: + os.fsync(metadata_file.fileno()) + except AttributeError: + os.sync() + ## move the old metadata + fs_writer.fs.rename(fs_writer.metadata_path, old_path) + try: + ## rename the new metadata + fs_writer.fs.rename(tmp_path, fs_writer.metadata_path) + + ## finally, remove the files we want to drop + for f in files_to_remove: + fs_writer.fs.rm_file(checkpoint_dir / f) + except Exception as e: + fs_writer.fs.rename(old_path, fs_writer.metadata_path) + raise e + else: + fs_writer.fs.rm_file(old_path) + + def can_handle_sharded_objects(self): + return True + + def check_backend_compatibility(self, loaded_version): + pass # TODO + + def check_version_compatibility(self, loaded_version): + pass # TODO diff --git a/megatron/core/dist_checkpointing/strategies/two_stage.py b/megatron/core/dist_checkpointing/strategies/two_stage.py index a9844ff6e5..50b31e2497 100644 --- a/megatron/core/dist_checkpointing/strategies/two_stage.py +++ b/megatron/core/dist_checkpointing/strategies/two_stage.py @@ -1,23 +1,22 @@ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. """ 2-stage checkpoint loading. """ -import os import time from collections import defaultdict from dataclasses import dataclass from functools import partial, wraps from itertools import chain -from logging import DEBUG, INFO, StreamHandler, getLogger +from logging import getLogger from operator import attrgetter, itemgetter from pathlib import Path -from typing import Iterable, List, NamedTuple, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values -from ..mapping import ShardedStateDict, ShardedTensor, StateDict +from ..mapping import ShardedStateDict, ShardedTensor from .base import LoadShardedStrategy -from .tensorstore import TensorStoreLoadShardedStrategy, _load_from_array, open_ts_array +from .tensorstore import _load_from_array, open_ts_array from .zarr import flatten_range, load_zarr_based_sharded_metadata _import_trigger = None @@ -26,9 +25,16 @@ timers = defaultdict(list) logger = getLogger(__name__) +logger.warning( + 'megatron.core.dist_checkpointing.two_stage module is deprecated' + ' and will be removed in Megatron-Core v0.12. 
Please use' + ' FullyParallelLoadStrategyWrapper to accomplish a parallelized checkpoint load.' +) def timed(verbose=True): + """Timing decorator.""" + def timed_dec(fn): name = fn.__name__ @@ -59,14 +65,12 @@ class _ShardedTensorMetadata: def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor): - return ( - sharded_tensor.key, - sharded_tensor.global_offset, - ) + """Id of a sharded tensor.""" + return (sharded_tensor.key, sharded_tensor.global_offset) class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy): - """ Loads one checkpoint replica from storage and broadcasts to other nodes. + """Loads one checkpoint replica from storage and broadcasts to other nodes. This strategy loads checkpoint from storage on minimal set of nodes and distributes the checkpoint to other nodes with torch.distributed. @@ -77,19 +81,18 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy): 1. Exchange ShardedTensors metadata between all nodes 2. Align needed tensors within DP groups 3. For each globally unique tensor: - a) on one of the ranks load it from storage to CPU and move to CUDA - b) allocate CUDA tensor on other ranks - c) broadcast within DP group - d) copy tensor content to the model param location - e) free tensor buffers from a) and b) + 3.a) on one of the ranks load it from storage to CPU and move to CUDA + 3.b) allocate CUDA tensor on other ranks + 3.c) broadcast within DP group + 3.d) copy tensor content to the model param location + 3.e) free tensor buffers from a) and b) Notes: 1. Loading and broadcasting is done sequentially to avoid both host and device OOMs 2. There is a lot of overlap potential between all three steps done for each tensor: - a) loading from storage to numpy - b) moving CPU tensors to CUDA - c) broadcast - + 2.a) loading from storage to numpy + 2.b) moving CPU tensors to CUDA + 2.c) broadcast """ def __init__(self, data_parallel_group, cpu_transfer=True): @@ -105,15 +108,19 @@ def __init__(self, data_parallel_group, cpu_transfer=True): self.global_rank = torch.distributed.get_rank() def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): + """Main load method.""" self.maybe_init_gloo_group() all_tensors_sorted = self._build_load_plan(sharded_state_dict) self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir) - self.summarize_load_times() + # TODO: fix hang in summarize_load_times + # self.summarize_load_times() return sharded_state_dict def summarize_load_times(self): + """Summarize load times.""" torch.distributed.barrier() logger.info('Checkpoint loading finished. 
Summary:') + # TODO: `timers` keys are not guaranteed to be the same across ranks which causes hangs for key, times in sorted(timers.items()): times_sum = sum(times) max_times = torch.tensor([times_sum], device='cuda') @@ -126,6 +133,7 @@ def summarize_load_times(self): @timed(verbose=False) def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata): + """Load tensor from storage.""" logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init') ret = _load_from_array( ten_meta.sharded_tensor_no_data, @@ -138,12 +146,15 @@ def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetad @timed() def maybe_init_gloo_group(self): + """Create Gloo groups.""" if not self.cpu_transfer: return all_groups = [None] * torch.distributed.get_world_size() torch.distributed.all_gather_object(all_groups, self.dp_group_ranks) all_groups = set(tuple(sorted(gr)) for gr in all_groups) for group_ranks in sorted(all_groups): + # "two_stage" module will be deprecated, so not replace new_group() + # with ...parallel_state.create_group() func setting group_desc here. gloo_pg = torch.distributed.new_group(ranks=group_ranks, backend='gloo') if self.global_rank in group_ranks: self.data_parallel_group = gloo_pg @@ -176,7 +187,7 @@ def _build_load_plan( @timed() def deduplicate_chunks(self, ten_metas: List[_ShardedTensorMetadata]): - """ Group tensors by chunk and then pick the tensor with the lowest rank. + """Group tensors by chunk and then pick the tensor with the lowest rank. NOTE: with proper loading overlap, loading from randomized ranks (instead of the smallest one) could be beneficial here. @@ -213,7 +224,8 @@ def _exchange_loaded_tensors( ) logger.debug( - f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})' + f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}\ +({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})' ) torch.distributed.broadcast( exchange_tensor, group=self.data_parallel_group, src=src_rank diff --git a/megatron/core/dist_checkpointing/strategies/zarr.py b/megatron/core/dist_checkpointing/strategies/zarr.py index 057d3248ee..6a52b12257 100644 --- a/megatron/core/dist_checkpointing/strategies/zarr.py +++ b/megatron/core/dist_checkpointing/strategies/zarr.py @@ -1,20 +1,28 @@ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. """ Strategies using Zarr as an underlying format. 
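A minimal, illustrative sketch of selecting this (deprecated) backend explicitly, assuming
the top-level `save` entrypoint accepts a `(backend, version)` pair the same way the load
path resolves tuples:

    from megatron.core import dist_checkpointing

    dist_checkpointing.save(sharded_state_dict, ckpt_dir, sharded_strategy=('zarr', 1))
    # Emits the deprecation warning from ZarrSaveShardedStrategy and writes one Zarr
    # array per ShardedTensor key under ckpt_dir.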
""" +import logging import os from functools import partial from logging import getLogger from pathlib import Path -from typing import Callable, List, Tuple +from typing import Callable, List, Optional, Tuple import numpy as np import torch import zarr from ..core import CheckpointingException -from ..dict_utils import dict_list_map_inplace +from ..dict_utils import dict_list_map_inplace, nested_values from ..mapping import ShardedStateDict, ShardedTensor, is_main_replica -from .base import LoadShardedStrategy, SaveShardedStrategy, StrategyAction, default_strategies +from .base import ( + LoadShardedStrategy, + SaveShardedStrategy, + StrategyAction, + register_default_strategy, +) + +logger = logging.getLogger(__name__) numpy_to_torch_dtype_dict = { np.dtype('bool'): torch.bool, @@ -34,7 +42,8 @@ try: - import tensorstore + # Register a bfloat16 type with this import + import tensorstore # pylint: disable=unused-import HAS_BFLOAT16 = True numpy_to_torch_dtype_dict[np.dtype('bfloat16')] = torch.bfloat16 @@ -42,13 +51,28 @@ except ImportError: HAS_BFLOAT16 = False -_import_trigger = None - logger = getLogger(__name__) +def register_default_zarr_strategies(): + """Register default strategies related to Zarr backend.""" + register_default_strategy( + StrategyAction.SAVE_SHARDED, 'zarr', 1, ZarrSaveShardedStrategy('zarr', 1) + ) + + class ZarrSaveShardedStrategy(SaveShardedStrategy): - def save(self, sharded_tensors: List[ShardedTensor], checkpoint_dir: Path): + """Save strategy for Zarr backend.""" + + def __init__(self, backend: str, version: int): + super().__init__(backend, version) + logger.warning( + f'`zarr` distributed checkpoint backend is deprecated.' + ' Please switch to PyTorch Distributed format (`torch_dist`).' + ) + + def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): + sharded_tensors = list(nested_values(sharded_state_dict)) arrays = _create_or_open_zarr_arrays(sharded_tensors, checkpoint_dir) for ten, arr in zip(sharded_tensors, arrays): _save_to_existing_array(ten, arr) @@ -57,24 +81,41 @@ def save(self, sharded_tensors: List[ShardedTensor], checkpoint_dir: Path): def _create_or_open_zarr_arrays( sharded_tensors: List[ShardedTensor], checkpoint_dir: Path -) -> List[zarr.Array]: +) -> List[Optional[zarr.Array]]: + """Returns list of zarr arrays corresponding to given tensors. 
+ + For a sharded tensors that: + a) is main replica and represents the first chunk (all offsets 0), creates the Zarr array + b) is main replica but not the first chunk, + opens the arrays created in (a) (possibly by other process) + c) otherwise, sets the corresponding array to None since it won't be used + + Args: + sharded_tensors (List[ShardedTensor]): sharded tensors from a given rank + that will be saved to checkpoint + checkpoint_dir (Path): checkpoint in which the arrays will be created + """ arrays = [] for ten in sharded_tensors: - if _should_create_array(ten): - _create_zarr_array(ten, checkpoint_dir) - # TODO: maybe reuse the opened arrays + arr = _create_zarr_array(ten, checkpoint_dir) if _should_create_array(ten) else None + arrays.append(arr) torch.distributed.barrier() - for ten in sharded_tensors: - # if is_main_replica(ten.replica_id) and set(ten.global_offset) == {0}: - # continue + # Open arrays created above by other processes + for arr_idx, ten in enumerate(sharded_tensors): + if arrays[arr_idx] is not None: + # array created by this process + assert _should_create_array(ten), ten + continue + if not is_main_replica(ten.replica_id): + # this array won't be needed for saving and can stay None + continue open_kwargs = {} if ten.flattened_range is not None: open_kwargs['synchronizer'] = zarr.ProcessSynchronizer( str(checkpoint_dir / f'{ten.key}.sync') ) - arr = zarr.open(checkpoint_dir / ten.key, 'r+', **open_kwargs) - arrays.append(arr) + arrays[arr_idx] = _open_zarr_array_verbose(checkpoint_dir / ten.key, 'r+', **open_kwargs) return arrays @@ -86,9 +127,10 @@ def _should_create_array(ten: ShardedTensor): ) -def _save_to_existing_array(sharded_tensor: ShardedTensor, arr: zarr.Array): +def _save_to_existing_array(sharded_tensor: ShardedTensor, arr: Optional[zarr.Array]): if not is_main_replica(sharded_tensor.replica_id): return + assert arr is not None x = sharded_tensor.data x = x.detach().cpu() torch.cuda.synchronize() @@ -117,6 +159,7 @@ def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path): fill_value=None, write_empty_chunks=True, ) + logger.debug(f'Created a new Zarr array at {checkpoint_dir / sharded_tensor.key}') except zarr.errors.ContainsArrayError as e: raise CheckpointingException( f'Array {checkpoint_dir / sharded_tensor.key} already exists' @@ -130,6 +173,8 @@ def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path): class ZarrLoadShardedStrategy(LoadShardedStrategy): + """Load strategy for the Zarr backend.""" + def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): dict_list_map_inplace( partial(_load_from_array, checkpoint_dir=checkpoint_dir), sharded_state_dict @@ -152,12 +197,7 @@ def check_version_compatibility(self, loaded_version): def _load_from_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path): assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor) - try: - arr = zarr.open(checkpoint_dir / sharded_tensor.key, 'r') - except zarr.errors.PathNotFoundError as e: - raise CheckpointingException( - f'Array {checkpoint_dir / sharded_tensor.key} not found' - ) from e + arr = _open_zarr_array_verbose(checkpoint_dir / sharded_tensor.key, 'r') if not sharded_tensor.allow_shape_mismatch and sharded_tensor.global_shape != arr.shape: _msg = ( @@ -171,7 +211,22 @@ def _load_from_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path): return postprocess_numpy_array(x, sharded_tensor) +def _open_zarr_array_verbose(path: Path, mode: str, **open_kwargs): + try: + return 
zarr.open(str(path), mode, **open_kwargs) + except zarr.errors.PathNotFoundError as e: + ckpt_dir = path.parent + err_msg = f'Array {path} not found' + if ckpt_dir.exists(): + ckpt_files = [f.name for f in ckpt_dir.iterdir()] + logger.debug(f'{err_msg}. Checkpoint directory {ckpt_dir} content: {ckpt_files}') + else: + err_msg += f'. Checkpoint directory {ckpt_dir} does not exist.' + raise CheckpointingException(err_msg) from e + + def postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=True): + """Turn numpy array to torch tensor.""" x = loaded_array if HAS_BFLOAT16 and x.dtype == np.dtype('bfloat16'): x = x.astype(np.dtype('float32')) @@ -199,10 +254,12 @@ def postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range= def flatten_range(sharded_tensor, x): + """Apply flattened range to a tensor.""" return x.flatten()[sharded_tensor.flattened_range] def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor): + """Pad tensor to the expected shape.""" pad_args = [] assert len(x.shape) == len(expected_sharded_ten.local_shape) # Reversed iteration order because F.pad expects so @@ -214,9 +271,10 @@ def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor): if x_sh == exp_sh: pad_args.extend((0, 0)) elif x_sh > exp_sh: - assert ( - False - ), f'Expected shape ({exp_sh}) smaller than actual ({x_sh}) for {repr(expected_sharded_ten)}' + assert False, ( + f'Expected shape ({exp_sh}) smaller than actual ({x_sh})' + f' for {repr(expected_sharded_ten)}' + ) else: pad_args.extend((0, exp_sh - x_sh)) # TODO: behavior control with envvar is for testing purposes only, remove it @@ -239,7 +297,7 @@ def load_zarr_based_sharded_metadata( ) -> ShardedStateDict: """Load metadata of Zarr arrays. - Arguments: + Args: checkpoint_dir (str): checkpoint root directory get_shape_dtype_fn (str -> ((int, ...), np.dtype)): a function returning an array shape and dtype for a given Zarr array path @@ -261,9 +319,3 @@ def load_zarr_based_sharded_metadata( tuple(1 for _ in arr_shape), ) return sharded_state_dict - - -# default_strategies[StrategyAction.LOAD_SHARDED.value][('zarr', 1)] = ZarrLoadShardedStrategy() -default_strategies[StrategyAction.SAVE_SHARDED.value][('zarr', 1)] = ZarrSaveShardedStrategy( - 'zarr', 1 -) diff --git a/megatron/core/dist_checkpointing/tensor_aware_state_dict.py b/megatron/core/dist_checkpointing/tensor_aware_state_dict.py new file mode 100644 index 0000000000..6521d869e6 --- /dev/null +++ b/megatron/core/dist_checkpointing/tensor_aware_state_dict.py @@ -0,0 +1,347 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. 
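# Editorial sketch (illustrative only, not part of this module's API): a checkpoint manager
# typically drives MCoreTensorAwareStateDict (defined below) roughly as follows, where
# `sharded_sd` is a regular MCore sharded state dict:
#
#     ta_sd, metadata_cache = MCoreTensorAwareStateDict.from_state_dict(sharded_sd)
#     tensors = ta_sd.pop_tensors()        # ta_sd becomes "hollow" (metadata only)
#     # ... persist `tensors` and ta_sd.common_state_dict ...
#     ta_sd.insert_tensors(tensors)        # exact inverse of pop_tensors()
#     state_dict = ta_sd.to_state_dict(sharded_sd)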
+ +""" Utilities for transforming state_dict, including a tensor-aware implementation.""" + +import logging +from dataclasses import dataclass +from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple + +import torch +from nvidia_resiliency_ext.checkpointing.local.base_state_dict import TensorAwareStateDict + +from .dict_utils import dict_list_map_inplace, dict_list_map_outplace, merge, nested_values +from .exchange_utils import ( + ShardDistribution, + determine_main_replica_uniform_distribution, + exchange_by_distribution, +) +from .mapping import ShardedObject, ShardedStateDict, ShardedTensor, StateDict, apply_factory_merges +from .state_dict_utils import load_preprocess, save_preprocess +from .utils import ( + _sharded_object_id, + _sharded_tensor_shard_id, + debug_time, + extract_sharded_base, + zip_strict, +) +from .validation import determine_global_metadata, validate_sharding_integrity + +logger = logging.getLogger(__name__) + + +@dataclass +class MCoreTensorAwareStateDict(TensorAwareStateDict): + """ + MCore-specific class defining the interface between the MCore state dict and checkpoint manager. + + This class distinguishes between raw objects, the common state dict, and sharded state dicts + (tensor parts). It also handles optional metadata needed for fully parallel save/load. + """ + + common: StateDict + sharded_state_dict: ShardedStateDict + _is_hollow: bool = False + + @staticmethod + def _validate_params(algo): + if algo != 'atomic' and algo != 'fully_parallel': + raise NotImplementedError( + 'Only "atomic" and "fully_parallel" sharding algorithms are supported.' + ) + + @staticmethod + def _get_distribution( + fully_parallel, sharded_part, parallelization_group, cached_distribution=None + ): + if fully_parallel: + if cached_distribution is None: + distribution = determine_main_replica_uniform_distribution( + sharded_part, parallelization_group, True + ) + logger.debug(f'MCore_TASD._get_distribution calculated distribution') + else: + distribution = cached_distribution + logger.debug(f'MCore_TASD._get_distribution used cache') + else: + distribution = (None, None, None, None) + logger.debug(f'MCore_TASD._get_distribution returned empty distribution') + return distribution + + @staticmethod + def _remove_redundant_data( + fully_parallel, sharded_part, shard_to_saving_rank, parallelization_group + ): + if fully_parallel: + for sh_base in nested_values(sharded_part): + # TODO remove redundant objects as well + if isinstance(sh_base, ShardedTensor): + shard_id = _sharded_tensor_shard_id(sh_base) + if shard_to_saving_rank[shard_id] != torch.distributed.get_rank( + group=parallelization_group + ): + sh_base.data = None + + @classmethod + @debug_time("from_state_dict", logger) + def from_state_dict( + cls, + sharded_state_dict: ShardedStateDict, + algo: str = 'fully_parallel', + parallelization_group: Optional[torch.distributed.ProcessGroup] = None, + cached_metadata: ShardDistribution = None, + ) -> Tuple[TensorAwareStateDict, ShardDistribution]: + """ + Constructs a TensorAwareStateDict from a sharded state dictionary. + + This method preprocesses the input `sharded_state_dict`, validates parameters, + and extracts the necessary data to create an instance of `MCoreTensorAwareStateDict`. + + Args: + sharded_state_dict: The input sharded state dictionary to be converted. + algo (str, optional): Initialization algorithm. Defaults to 'fully_parallel'. + - 'fully_parallel' enables fully parallel initialization. 
+ parallelization_group (Optional): A distributed process group for parallelization. + cached_metadata (Optional): Precomputed metadata from previous saves. + - Reuses data that doesn't need recalculation, optimizing the creation process. + + Returns: + TensorAwareStateDict: An instance initialized with the provided sharded state dictionary + and optional cached metadata. + - The metadata is stored in memory to speed up future saves. + """ + with debug_time("_get_distribution", logger): + cls._validate_params(algo) + fully_parallel = algo == 'fully_parallel' + sharded_part, common_state_dict = save_preprocess( + sharded_state_dict, cached_metadata is None + ) + cacheable_distribution = cls._get_distribution( + fully_parallel, sharded_part, parallelization_group, cached_metadata + ) + if cacheable_distribution is not None: + shard_to_saving_rank, _, _, _ = cacheable_distribution + cls._remove_redundant_data( + fully_parallel, sharded_part, shard_to_saving_rank, parallelization_group + ) + + return ( + MCoreTensorAwareStateDict(common=common_state_dict, sharded_state_dict=sharded_part), + cacheable_distribution, + ) + + @property + def is_hollow(self): + """ + True iff tensors had been extracted and have not been inserted back yet. + """ + return self._is_hollow + + @property + def _sharded_tensors(self): + # Three possible states for sharded_tensor: + # 1. sharded_tensor with data (.data = tensor) + # 2. sharded_tensor hollow (.data = None, .orig_device = orig_device) + # 3. removed sharded_tensor (.data = None, no device information) + # TODO: Consider simplifying by removing the entire sharded_tensor instead of just the data + if self.is_hollow: + for sh_base in nested_values(self.sharded_state_dict): + # FIXME: Hacky way to store the original device of the popped tensor + if isinstance(sh_base, ShardedTensor) and hasattr(sh_base, 'orig_device'): + yield sh_base + else: + for sh_base in nested_values(self.sharded_state_dict): + if isinstance(sh_base, ShardedTensor) and sh_base.data is not None: + yield sh_base + + @property + def tensors(self) -> Iterator[torch.Tensor]: + """ + Get the tensor data from the state dict. + """ + assert not self.is_hollow # TODO raise exception + return map(lambda sh_ten: sh_ten.data, self._sharded_tensors) + + @property + def common_state_dict(self) -> Dict: + """ + Get the common state dict from the state dict. + """ + return self.common + + def pop_tensors(self) -> List[torch.Tensor]: + """ + Extracts the tensor data from the wrapped state dict, preserving metadata. + + Replaces the tensor data in sharded_tensors with device type of extracted tensors. + After this operation, the state dictionary is "hollow", containing no tensor data. + Further calls to `pop_tensor` will raise an error. + + @return List of extracted tensors + """ + assert not self.is_hollow # TODO raise exception + result = [] + for sh_ten in self._sharded_tensors: + result.append(sh_ten.data) + # FIXME: Hacky way to store the original device, which is not included in the metadata + setattr(sh_ten, 'orig_device', sh_ten.data.device.type) + sh_ten.data = None + self._is_hollow = True + return result + + def insert_tensors(self, tensor_data: Iterable[torch.Tensor]): + """ + Reverse of `pop_tensors`. 
Replaces device type in sharded_tensors with actual values + Value of `self` is considered to be the same after: + ``` + self.insert_tensors(self.pop_tensors()) + ``` + """ + assert self.is_hollow # TODO raise exception + for sh_ten, ten in zip_strict(self._sharded_tensors, tensor_data): + # FIXME: Hacky way to store the original device + if sh_ten.orig_device == ten.device.type: + delattr(sh_ten, 'orig_device') + # Tensor might be on non-original device + sh_ten.data = ten + self._is_hollow = False + + def init_tensors(self): + """ + Initializes empty tensors with the same properties as the original tensors. + + This function should only be called after the original tensors have been popped. + It ensures that the newly created empty tensors match the shape, + dtype, and device of the originals, but contain no data. + """ + assert self.is_hollow # TODO raise exception + for sh_ten in self._sharded_tensors: + # Hacky way to retrieve the original device + sh_ten.init_data(sh_ten.orig_device) + delattr(sh_ten, 'orig_device') + self._is_hollow = False + + def copy_tensors_to_cpu(self, non_blocking=False): + """ + Stores CPU copies of tensors in the state_dict, replacing the originals, + but without destroying them. + The original devices are remembered for restoration with restore_tensor_device(). + Using non_blocking=True allows for asynchronous copying. + """ + assert not self.is_hollow # TODO raise exception + for sh_ten in self._sharded_tensors: + if sh_ten.data.device.type == 'cpu': + # Skip cloning if it's already confirmed to be a copy + if not hasattr(sh_ten, 'orig_device'): + sh_ten.data = sh_ten.data.clone() + else: + # FIXME: Hacky way to store the original device + if not hasattr(sh_ten, 'orig_device'): + setattr(sh_ten, 'orig_device', sh_ten.data.device.type) + sh_ten.data = sh_ten.data.detach().to("cpu", non_blocking=non_blocking) + + def restore_tensor_device(self, non_blocking=True): + """ + Restores all tensors to their original devices, if a move is required. + Using non_blocking=True allows for asynchronous copying. 
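
        Example (illustrative; `ta_sd` is an MCoreTensorAwareStateDict holding tensor data):

            ta_sd.copy_tensors_to_cpu(non_blocking=True)  # replace .data with CPU copies
            torch.cuda.synchronize()                      # wait for the async copies to finish
            # ... persist the CPU copies ...
            ta_sd.restore_tensor_device()                 # move tensors back to remembered devices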
+ """ + assert not self.is_hollow # TODO raise exception + for sh_ten in self._sharded_tensors: + # FIXME: Hacky way to store the original device + if hasattr(sh_ten, 'orig_device'): + sh_ten.data = sh_ten.data.to(sh_ten.orig_device, non_blocking=non_blocking) + delattr(sh_ten, 'orig_device') + + def _insert_sharded_data( + self, fully_parallel, sharded_part, parallelization_group, exchange_algo + ): + loaded_tensors = {} + for sh_ten in self._sharded_tensors: + loaded_tensors[_sharded_tensor_shard_id(sh_ten)] = sh_ten.data + if fully_parallel: + with debug_time("_get_distribution", logger): + distribution = self._get_distribution( + fully_parallel, sharded_part, parallelization_group + ) + if distribution is not None: + unloaded_shards = {} + for sh_base in nested_values(sharded_part): + # TODO retrieve redundant ShardedObjects once removed in _remove_redundant_data + if isinstance(sh_base, ShardedTensor): + shard_id = _sharded_tensor_shard_id(sh_base) + if shard_id not in loaded_tensors: + unloaded_shards[shard_id] = sh_base + + with debug_time("exchange_by_distribution", logger): + loaded_tensors = exchange_by_distribution( + loaded_tensors, + unloaded_shards, + distribution, + parallelization_group, + exchange_algo, + ) + torch.cuda.synchronize() + loaded_objects = {} + for sh_base in nested_values(self.sharded_state_dict): + if not isinstance(sh_base, ShardedTensor): + assert isinstance(sh_base, ShardedObject) + loaded_objects[_sharded_object_id(sh_base)] = sh_base.data + + def load_sharded_base(x: Any): + if isinstance(x, ShardedTensor): + shard_id = _sharded_tensor_shard_id(x) + assert shard_id in loaded_tensors, (x, shard_id, loaded_tensors.keys()) + x = loaded_tensors[shard_id] + if isinstance(x, ShardedObject): + object_id = _sharded_object_id(x) + assert object_id in loaded_objects, (x, object_id, loaded_objects.keys()) + x = loaded_objects[object_id] + return x + + dict_list_map_inplace(load_sharded_base, sharded_part) + + @debug_time("to_state_dict", logger) + def to_state_dict( + self, + sharded_state_dict: ShardedStateDict, + algo: str = 'atomic', + exchange_algo: str = 'broadcast', + validate_access_integrity: bool = True, + parallelization_group: Optional[torch.distributed.ProcessGroup] = None, + ): + """ + Convert tensor-aware dict back to the original state_dict + """ + with debug_time("load_preprocess_and_state_dict_manipulations", logger): + assert not self.is_hollow # TODO raise exception + self._validate_params(algo) + fully_parallel = algo == 'fully_parallel' + + # __adding__ common part + recreated_state_dict = dict_list_map_outplace(lambda x: x, self.common) + + if not sharded_state_dict: + return recreated_state_dict + # TODO validate self.sharded_state_dict"] and sharded_state_dict are compatible + + sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess( + sharded_state_dict + ) + # __adding__ nonpersistent part + merge(recreated_state_dict, nonpersistent_state_dict) + + sharded_part, _ = extract_sharded_base(sharded_state_dict) + + if validate_access_integrity: + with debug_time("validate_sharding_integrity", logger): + validate_sharding_integrity(determine_global_metadata(sharded_part)[1]) + + # load sharded tensors and sharded objects to sharded_part + with debug_time("_insert_sharded_data", logger): + self._insert_sharded_data( + fully_parallel, sharded_part, parallelization_group, exchange_algo + ) + with debug_time("apply_factory_merges", logger): + sharded_part = apply_factory_merges(sharded_part, sh_ten_factories) + # 
__adding__ sharded_part + merge(recreated_state_dict, sharded_part) + return recreated_state_dict diff --git a/megatron/core/dist_checkpointing/utils.py b/megatron/core/dist_checkpointing/utils.py index a40142f38d..b7a1e0ca29 100644 --- a/megatron/core/dist_checkpointing/utils.py +++ b/megatron/core/dist_checkpointing/utils.py @@ -1,29 +1,319 @@ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. -from typing import Tuple +""" Helpers for manipulating sharded tensors and sharded state dicts. """ +import logging +from contextlib import contextmanager +from time import time +from typing import Dict, Optional, Tuple from .dict_utils import dict_list_map_inplace, extract_matching_values -from .mapping import LocalNonpersitentObject, ShardedStateDict, ShardedTensor, StateDict +from .mapping import ( + LocalNonpersistentObject, + ShardedBase, + ShardedObject, + ShardedStateDict, + ShardedTensor, + ShardedTensorFactory, + StateDict, +) + +# _ShardId uniquely identifies a ShardedTensor. This is a subset of ShardedTensor +# attributes: key (str), global_offset (tuple) and flattened_range (optional tuple) +_ShardId = Tuple[str, tuple, Optional[tuple]] + + +def zip_strict(*args): + """ + Alternative to Python's builtin zip(..., strict=True) (available in 3.10+). + Apart from providing functionality in earlier versions of Python is also more verbose. + (Python's zip does not print lengths, only which iterable has finished earlier) + """ + args = [list(a) for a in args] + lens = [len(a) for a in args] + assert len(set(lens)) <= 1, f"Tried to zip iterables of unequal lengths: {lens}!" + return zip(*args) + + +def _sharded_tensor_shard_id(sharded_tensor: ShardedTensor) -> _ShardId: + """Unique id of the sharded tensor data. + + Should yield the same value for same data replicated on different ranks. + + Args: + sharded_tensor (ShardedTensor): sharded tensor representing the data shard + + Returns (tuple): unique id of a data shard + """ + f_range = sharded_tensor.flattened_range + return ( + sharded_tensor.key, + sharded_tensor.global_offset, + None if f_range is None else (f_range.start, f_range.stop), + ) + + +def _sharded_object_id(sharded_object: ShardedObject) -> _ShardId: + """Unique id of the sharded object data. + + Should yield the same value for same data replicated on different ranks. + + Args: + sharded_object (ShardedObject): sharded object representing the data shard + + Returns (tuple): unique id of a data shard + """ + return (sharded_object.key, sharded_object.global_offset, sharded_object.global_shape) def extract_sharded_tensors( sharded_state_dict: ShardedStateDict, ) -> Tuple[ShardedStateDict, StateDict]: + """Extract a dict consisting of only ShardedTensor objects + from a given state dict with any objects. + + Args: + sharded_state_dict: state dict possibly containing ShardedTensor objects + + Returns: + Tuple[ShardedStateDict, StateDict]: tuple of: + - state dict with all ShardedTensor (keeping the original state dict structure) + - state dict with all objects other than ShardedTensor + (keeping the original state dict structure) + """ return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedTensor)) +def extract_sharded_tensors_and_factories( + sharded_state_dict: ShardedStateDict, +) -> Tuple[ShardedStateDict, StateDict]: + """Extract a dict consisting of only ShardedTensor and ShardedTensorFactory objects + from a given state dict with any objects. 
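
    For example (illustrative):

        with_tensors, rest = extract_sharded_tensors_and_factories(sharded_sd)
        # `with_tensors` keeps only ShardedTensor / ShardedTensorFactory leaves,
        # `rest` keeps everything else; both preserve the original nesting.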
+ + Args: + sharded_state_dict: + state dict possibly containing ShardedTensor and ShardedTensorFactory objects + + Returns: + Tuple[ShardedStateDict, StateDict]: tuple of: + - state dict with all ShardedTensor and ShardedTensorFactory + (keeping the original state dict structure) + - state dict with all other objects (keeping the original state dict structure) + """ + return extract_matching_values( + sharded_state_dict, lambda v: isinstance(v, (ShardedTensor, ShardedTensorFactory)) + ) + + def extract_sharded_tensors_or_nonpersistent( sharded_state_dict: ShardedStateDict, ) -> Tuple[ShardedStateDict, StateDict]: + """Extract a dict consisting of only ShardedTensor, ShardedTensorFactory + and LocalNonpersistentObject objects from a given state dict with any objects. + + Args: + sharded_state_dict: state dict possibly containing ShardedTensor, ShardedTensorFactory + and LocalNonpersistentObject objects + + Returns: + Tuple[ShardedStateDict, StateDict]: tuple of: + - state dict with all ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject + (keeping the original state dict structure) + - state dict with all other objects (keeping the original state dict structure) + """ return extract_matching_values( - sharded_state_dict, lambda v: isinstance(v, (ShardedTensor, LocalNonpersitentObject)) + sharded_state_dict, + lambda v: isinstance(v, (ShardedTensor, LocalNonpersistentObject, ShardedTensorFactory)), + ) + + +def extract_sharded_base( + sharded_state_dict: ShardedStateDict, +) -> Tuple[ShardedStateDict, StateDict]: + """Extract a dict consisting of only ShardedBase from a given state dict with any objects. + + Args: + sharded_state_dict: state dict possibly containing ShardedBase objects + + Returns: + Tuple[ShardedStateDict, StateDict]: tuple of: + - state dict with all ShardedBase objects (keeping the original state dict structure) + - state dict with all other objects (keeping the original state dict structure) + """ + return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedBase)) + + +def extract_nonpersistent( + sharded_state_dict: ShardedStateDict, +) -> Tuple[ShardedStateDict, StateDict]: + """Extract a dict consisting of only LocalNonpersistentObjects from a given state dict. + + Args: + sharded_state_dict: state dict possibly containing LocalNonpersistentObjects + + Returns: + Tuple[ShardedStateDict, StateDict]: tuple of: + - state dict with all LocalNonpersistentObjects + (keeping the original state dict structure) + - state dict with all other objects (keeping the original state dict structure) + """ + + return extract_matching_values( + sharded_state_dict, lambda v: isinstance(v, LocalNonpersistentObject) ) def add_prefix_for_sharding(sharded_state_dict: ShardedStateDict, prefix: str): + """Prepend a given prefix to all ShardedBase objects in a given state dict *in-place*. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict + prefix (str): prefix to be prepended + + Returns: + None: state dict is modified in-place + """ + def add_prefix(t): - if isinstance(t, ShardedTensor): - t.key = f'{prefix}.{t.key}' + if isinstance(t, ShardedBase): + t.key = f'{prefix}{t.key}' return t dict_list_map_inplace(add_prefix, sharded_state_dict) + + +def replace_prefix_for_sharding( + sharded_state_dict: ShardedStateDict, old_prefix: str, new_prefix: str +): + """Replaces the given prefix in *all* sharded keys in a given state dict. + + Errors out if some key does not begin with a given prefix. 
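
    For example (illustrative; prefixes are hypothetical):

        replace_prefix_for_sharding(sharded_sd, 'module.', '')
        # 'module.decoder.weight' becomes 'decoder.weight'; a key that does not start
        # with 'module.' raises a ValueError.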
+ + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in + old_prefix (str): prefix to be replaced in each key + new_prefix (str): new prefix + + Returns: + None: state dict is modified in place + """ + + def _replace_prefix(x): + if isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)): + if not x.key.startswith(old_prefix): + raise ValueError(f'Expected {x.key} to begin with prefix {old_prefix}') + x.key = f'{new_prefix}{x.key[len(old_prefix):]}' # str.removeprefix in Python >= 3.9 + return x + + dict_list_map_inplace(_replace_prefix, sharded_state_dict) + + +def apply_prefix_mapping(sharded_state_dict: ShardedStateDict, prefix_map: Dict[str, str]): + """Replaces prefixes *only in keys matching* with one of prefixes in the map. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in + prefix_map (Dict[str, str]): + map of old->new prefixes. The first matching prefix for each key is used + + Returns: + None: state dict is modified in place + """ + + def _replace_prefixes(x): + if not isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)): + return x + for old_prefix, new_prefix in prefix_map.items(): + if x.key.startswith(old_prefix): + x.key = ( + f'{new_prefix}{x.key[len(old_prefix):]}' # str.removeprefix in Python >= 3.9 + ) + break + return x + + dict_list_map_inplace(_replace_prefixes, sharded_state_dict) + + +fallback_logger = logging.getLogger(__name__) +__LOGGER_NAME_STACK = [] +__LOGGER_STACK = [] + + +@contextmanager +def logger_stack(name: Optional[str] = None, current_logger: Optional[logging.Logger] = None): + """Context manager for managing logger and name stack. + + Temporarily pushes a logger and/or name onto their respective stacks, allowing hierarchical + logging and contextual logger usage. Ensures the logger stack is restored afterward. + + Args: + name (str, optional): Name to add to the logger stack. Defaults to None. + current_logger (logging.Logger, optional): Logger to use. Defaults to the last logger in + the stack or a fallback if none exist. + + Yields: + Tuple[str, logging.Logger]: A tuple with the concatenated logger name stack and + the current logger for the block. + + Example: + with logger_stack("scope", logger): + logger.info("Log within 'scope'") + """ + if name: + __LOGGER_NAME_STACK.append(name) + if current_logger: + __LOGGER_STACK.append(current_logger) + last_logger = current_logger + elif __LOGGER_STACK: + last_logger = __LOGGER_STACK[-1] + else: + last_logger = fallback_logger + try: + yield ".".join(__LOGGER_NAME_STACK), last_logger + finally: + if name and __LOGGER_NAME_STACK: + __LOGGER_NAME_STACK.pop(-1) + if current_logger and __LOGGER_STACK: + __LOGGER_STACK.pop(-1) + + +@contextmanager +def debug_time( + name: str, logger: Optional[logging.Logger] = None, threshold: float = float("-inf"), level=None +): + """Simple context manager for timing functions/code blocks. + + Args: + name (str): Label describing the code being measured. + logger (logging.Logger, optional): Logger for output. Defaults to the lowest logger. + threshold (float, optional): Minimum time (seconds) to log. Skips logging if faster. + level (int, optional): Logging level. Defaults to DEBUG if `threshold` is unset; + WARNING otherwise. 
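
    Example (illustrative; `do_save` is a hypothetical callable):

        with debug_time("save_sharded", logger, threshold=1.0):
            do_save()
        # With `threshold` set, "save_sharded took X.XXXXs" is logged at WARNING level,
        # and only if the block took at least 1 second.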
+ """ + with logger_stack(name, logger) as (stacked_name, last_logger): + start = time() + try: + yield + finally: + result = time() - start + if result < threshold: + return + if level is None: + level = logging.DEBUG if threshold == float("-inf") else logging.WARNING + last_logger.log(level, f"{stacked_name} took {result:.4f}s") + + +def debug_msg(msg: str): + """Logs a debug message using the current logger stack. + + This function formats and logs a debug message with the current logger + and name stack, preserving context from the logger_stack context manager. + + Args: + msg (str): The message to be logged at the debug level. + + Example: + debug_msg("Checkpoint initialized") + # Logs: "scope_name Checkpoint initialized" if called within logger_stack("scope_name") + """ + with logger_stack(None, None) as (stacked_name, last_logger): + last_logger.debug(f"{stacked_name} {msg}") diff --git a/megatron/core/dist_checkpointing/validation.py b/megatron/core/dist_checkpointing/validation.py new file mode 100644 index 0000000000..546ec3547f --- /dev/null +++ b/megatron/core/dist_checkpointing/validation.py @@ -0,0 +1,560 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import logging +from collections import Counter, defaultdict +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union + +import numpy as np +import torch + +from megatron.core.dist_checkpointing import ShardedTensor +from megatron.core.dist_checkpointing.core import CheckpointingException, maybe_load_config +from megatron.core.dist_checkpointing.dict_utils import ( + diff, + extract_matching_values, + map_reduce, + nested_values, +) +from megatron.core.dist_checkpointing.mapping import ( + CommonStateDict, + ShardedBase, + ShardedObject, + ShardedStateDict, + is_main_replica, +) +from megatron.core.dist_checkpointing.strategies.base import ( + LoadCommonStrategy, + LoadShardedStrategy, + SaveCommonStrategy, + SaveShardedStrategy, + StrategyAction, + get_default_strategy, +) + +if TYPE_CHECKING: + from megatron.core.dist_checkpointing.serialization import CkptShardedMetadata + +logger = logging.getLogger(__name__) +# pylint: disable=line-too-long +# list of local saved/loaded ShardedBase objects +_LocalMetadata = List[Union[ShardedTensor, ShardedObject]] +# list of lists of global saved/loaded ShardedBase objects (each element corresponds to global rank) +_GlobalMetadata = List[_LocalMetadata] + + +class StrictHandling(Enum): + """Determines handling of load mismatch (non-empty "unexpected" or "missing" keys). + + Different flags carry different implications on performance and behaviour and + are divided into two groups: + - *_UNEXPECTED + - *_ALL + The first group ignores missing keys (present in the checkpoint but missing + in the sharded state dict) which is created in order to avoid inter-rank + metadata exchange. Note that the metadata exchange will happen anyway + with `load(..., validate_access_integrity=True)` flag in which case using the + `*_ALL` option is recommended as it provides a more thorough check with no + performance penalty wrt. `*_UNEXPECTED` group. + + All options except for the first one (`ASSUME_OK_UNEXPECTED`) require + extra disk access before the load in order to remove unexpected keys + from the sharded state dict requested to load. + """ + + # Relies on the underlying strategy to raise error on unexpected keys + ASSUME_OK_UNEXPECTED = 'assume_ok_unexpected' + # Logs (with WARNING level) "unexpected" keys. 
Missing keys are ignored. + # This is treated as a reasonable default for a "non-strict" load + LOG_UNEXPECTED = 'log_unexpected' + # Logs (with WARNING level) all mismatched keys. + LOG_ALL = 'log_all' + # Raise error on unexpected keys before load attempt. + # Gives cleaner error message than `ASSUME_OK_UNEXPECTED` but requires + # extra disk access. + RAISE_UNEXPECTED = 'raise_unexpected' + # Raise error on any mismatch. Similar to `RAISE_UNEXPECTED` but requires + # metadata exchange. + RAISE_ALL = 'raise_all' + # "Unexpected" mismatches are not reported, but returned by the `load` + # function along with the loaded state dict. Missing keys are ignored. + RETURN_UNEXPECTED = 'return_unexpected' + # All mismatches are returned along with the loaded state dict. + RETURN_ALL = 'return_all' + # Simply ignores mismatches (not recommended) + IGNORE_ALL = 'ignore_all' + + @staticmethod + def requires_explicit_ckpt_mismatch_check(val: 'StrictHandling') -> bool: + """Whether a given strict flag involves mismatch check against the checkpoint.""" + return val != StrictHandling.ASSUME_OK_UNEXPECTED + + @staticmethod + def requires_global_app_metadata(val: 'StrictHandling') -> bool: + """Whether a given strict option requires global metadata for validation.""" + return val in ( + StrictHandling.IGNORE_ALL, + StrictHandling.RAISE_ALL, + StrictHandling.RETURN_ALL, + StrictHandling.LOG_ALL, + ) + + @staticmethod + def requires_returning_mismatch_keys(val: 'StrictHandling') -> bool: + """Whether a given strict option results in extra return value from the `load` function.""" + return val in (StrictHandling.RETURN_UNEXPECTED, StrictHandling.RETURN_ALL) + + +def parse_strict_flag(strict: Union[str, StrictHandling]) -> StrictHandling: + """Parse user passed strict flag from a string to StrictHandling instance. + + Args: + strict (str, StrictHandling): strict flag to parse. If already an instance + of StrictHandling, this function is a noop. + + Returns: + StrictHandling: enum instance + """ + if isinstance(strict, StrictHandling): + return strict + try: + return StrictHandling(strict) + except (ValueError, TypeError) as e: + raise ValueError(f'Invalid strict flag: {e}') from e + + +def validate_integrity_and_strict_load( + sharded_state_dict: ShardedStateDict, + strict: StrictHandling, + validate_access_integrity: bool, + local_metadata: Optional[_LocalMetadata] = None, + global_metadata: Optional[_GlobalMetadata] = None, + ckpt_sharded_metadata: Optional['CkptShardedMetadata'] = None, +) -> Tuple[ShardedStateDict, Set[str], Set[str]]: + """Validates sharding integrity and potential mismatches with the checkpoint. + + `validate_access_integrity` controls sharding integrity check (orthogonal + to strictness checking) which verifies `sharded_state_dict` runtime completeness + (in isolation from the actual checkpoint). + + `strict` flag controls handling of mismatches between the requested + sharded state dict to load and the actual checkpoint. See `StrictHandling` + docs for details regarding flag behavior and performance implications + (disk interactions or inter-rank communication). + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to verify. + strict (StrictHandling): flag determining how to handle sharded keys mismatch. + validate_access_integrity (bool): whether to perform sharding validation. + local_metadata (_LocalMetadata, optional): local sharded state dict metadata. + Defaults to None, in which case it's determined based on `sharded_state_dict`. 
+ global_metadata (_GlobalMetadata, optional): global sharded state dict metadata + (exchanged between ranks). Defaults to None, in which case "missing" + keys are not determined. + ckpt_sharded_metadata (CkptShardedMetadata, optional): sharded metadata + from the checkpoint. Defaults to None, which only makes sense + for the `StrictHandling.ASSUME_OK_UNEXPECTED` strict value. + + Returns: + Tuple[ShardedStateDict, Set[str], Set[str]]: tuple of: sharded state dict + without unexpected keys, missing and unexpected keys. Missing keys are equal + on all ranks, unexpected keys might differ across ranks. Additionally, + missing keys might be erroneously empty (depending on `strict` value). + """ + missing_keys, unexpected_keys = [], [] + if StrictHandling.requires_explicit_ckpt_mismatch_check(strict): + if ckpt_sharded_metadata is None: + raise CheckpointingException( + 'Cannot verify checkpoint mismatch with ckpt_sharded_metadata=None.' + ) + if local_metadata is None: + local_metadata = [ + sh_base.without_data() for sh_base in nested_values(sharded_state_dict) + ] + # We don't want to check for missing keys even if we could + _skip_missing_keys = strict in ( + StrictHandling.ASSUME_OK_UNEXPECTED, + StrictHandling.LOG_UNEXPECTED, + StrictHandling.RAISE_UNEXPECTED, + StrictHandling.RETURN_UNEXPECTED, + ) + missing_keys, unexpected_keys = _determine_missing_and_unexpected_keys( + ckpt_sharded_metadata, local_metadata, None if _skip_missing_keys else global_metadata + ) + + sharded_state_dict = adjust_non_strict_load(sharded_state_dict, unexpected_keys) + + if strict == StrictHandling.IGNORE_ALL: + missing_keys, unexpected_keys = [], [] + elif strict in (StrictHandling.RAISE_UNEXPECTED, StrictHandling.RAISE_ALL): + maybe_report_missing_and_unexpected_keys(missing_keys, unexpected_keys, True) + elif strict in (StrictHandling.LOG_UNEXPECTED, StrictHandling.LOG_ALL): + maybe_report_missing_and_unexpected_keys(missing_keys, unexpected_keys, False) + + if validate_access_integrity: + if global_metadata is None: + raise CheckpointingException( + 'Cannot check sharding intergrity without global_metadata (None).' + ) + validate_sharding_integrity(global_metadata) + + return sharded_state_dict, missing_keys, unexpected_keys + + +def verify_checkpoint_and_load_strategy( + checkpoint_dir: str, + sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None, + common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None, +) -> Tuple[LoadShardedStrategy, LoadCommonStrategy]: + """Verifies if checkpoint metadata exists and matches given strategies. + + If no strategies are passed, they are determined based on the checkpoint metadata. + + Args: + checkpoint_dir (str): checkpoint directory + sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): sharded load strategy to be verified + if compatible with the checkpoint content. If None, the default sharded load strategy + for the checkpoint backend will be returned. + common_strategy (LoadCommonStrategy, Tuple[str, int], optional): common load strategy to be verified + if compatible with the checkpoint content. If None, the default common load strategy + for the checkpoint backend will be returned. 
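
    Example (illustrative; the path is a placeholder):

        sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
            '/path/to/ckpt', sharded_strategy=('torch_dist', 1)
        )
        # The tuple is resolved via get_default_strategy and both strategies are then
        # checked against the backend/version recorded in the checkpoint's saved config.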
+ """ + if not Path(checkpoint_dir).exists(): + raise CheckpointingException(f'Checkpoint directory {checkpoint_dir} does not exist') + + saved_config = maybe_load_config(checkpoint_dir) + if saved_config is None: + raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint') + + if sharded_strategy is None: + sharded_strategy = get_default_strategy( + StrategyAction.LOAD_SHARDED, + saved_config.sharded_backend, + saved_config.sharded_backend_version, + ) + elif isinstance(sharded_strategy, tuple): + sharded_strategy = get_default_strategy(StrategyAction.LOAD_SHARDED, *sharded_strategy) + + if common_strategy is None: + common_strategy = get_default_strategy( + StrategyAction.LOAD_COMMON, + saved_config.common_backend, + saved_config.common_backend_version, + ) + elif isinstance(common_strategy, tuple): + sharded_strategy = get_default_strategy(StrategyAction.LOAD_COMMON, *common_strategy) + + sharded_strategy.check_backend_compatibility(saved_config.sharded_backend) + sharded_strategy.check_version_compatibility(saved_config.sharded_backend_version) + common_strategy.check_backend_compatibility(saved_config.common_backend) + common_strategy.check_version_compatibility(saved_config.common_backend_version) + return sharded_strategy, common_strategy + + +def adjust_non_strict_load( + sharded_state_dict: ShardedStateDict, sharded_keys_to_remove: Set[str] +) -> ShardedStateDict: + """Adjusts sharded state dict removing keys not existing in the checkpoint. + + Args: + sharded_state_dict (ShardedStateDict): sharded state dict to modify + sharded_keys_to_remove (Set[str]): keys to remove from the state dict + + Returns: + ShardedStateDict: state dict without ShardedBase objects with specified keys + """ + + def is_unexpected_key(x: ShardedBase): + assert isinstance(x, ShardedBase), f'Unexpected type {type(x)}' + return x.key in sharded_keys_to_remove + + _, sharded_state_dict = extract_matching_values(sharded_state_dict, is_unexpected_key) + return sharded_state_dict + + +def _determine_missing_and_unexpected_keys( + ckpt_sharded_metadata: 'CkptShardedMetadata', + local_metadata: _LocalMetadata, + global_metadata: Optional[_GlobalMetadata] = None, +) -> Tuple[Set[str], Set[str]]: + """Determines load mismatches based on metadata. + + There is an asymmetry between "unexpected" and "missing" keys. + Unexpected keys can be determined based only on local metadata. + Missing keys must be based on global metadata, since other ranks might access + different keys than the current rank. + In consequence, the return value of this function is different on each rank: + "missing_keys" are equal, but "unexpected_keys" might differ across ranks. + + Args: + ckpt_sharded_metadata (CkptShardedMetadata): sharded state dict (without data) + constructed based on the checkpoint content + local_metadata (_LocalMetadata): list of local ShardedBase objects + requested to be loaded by this rank + global_metadata (_GlobalMetadata, optional): list of global ShardedBase objects + requested to be loaded by all ranks. Defaults to None, in which case + returned "missing" keys are empty. + + Returns: + Tuple[Set[str], Set[str]]: missing and unexpected keys. Missing keys are equal + on all ranks, unexpected keys might differ across ranks. If passed + `global_metadata` is empty, returned missing keys are empty as well. 
+ + """ + local_accessed_keys = set(sh_base.key for sh_base in local_metadata) + ckpt_keys = set(sh_base.key for sh_base in ckpt_sharded_metadata.values()) + unexpected_keys = local_accessed_keys - ckpt_keys + if global_metadata is not None: + global_accessed_keys = set( + sh_base.key for rank_metadata in global_metadata for sh_base in rank_metadata + ) + missing_keys = ckpt_keys - global_accessed_keys + else: + missing_keys = set() + + if missing_keys: + logger.debug(f'Dist ckpt load missing keys: {missing_keys}') + if unexpected_keys: + logger.debug(f'Dist ckpt load unexpected keys: {unexpected_keys}') + + return missing_keys, unexpected_keys + + +def maybe_report_missing_and_unexpected_keys( + missing_keys: Set[str], unexpected_keys: Set[str], raise_error: bool = True +) -> None: + """Raises or logs an error in case missing or unexpected keys are non-empty. + + Args: + missing_keys (Set[str]): missing keys in the state dict + unexpected_keys (Set[str]): unexpected keys in the state dict + raise_error: If True, raises error on mismatch. Otherwise, logs mismatch + with WARNING level. + + Returns: + None + + Raises: + CheckpointingException: if `raise_error` is True and at least one of + `missing_keys` or `unexpected_keys` are non-empty. + """ + if not missing_keys and not unexpected_keys: + return + missing_title_msg = ( + f'Some keys found in the checkpoint are missing in the provided sharded state dict. ' + ) + missing_body_msg = f'Missing keys (for all ranks): {missing_keys}. ' + unexpected_title_msg = f'Unexpected keys (not found in the checkpoint) encountered in the provided sharded state dict. ' + unexpected_body_msg = f'Unexpected keys (for this rank): {unexpected_keys}. ' + error_msg = '' + if missing_keys: + error_msg += missing_title_msg + if unexpected_keys: + error_msg += unexpected_title_msg + + error_msg += '\n' + if missing_keys: + error_msg += missing_body_msg + if unexpected_keys: + error_msg += unexpected_body_msg + + if raise_error: + raise CheckpointingException(error_msg) + else: + logger.warning(error_msg) + + +def _validate_common_state_dict(common_state_dict: CommonStateDict) -> None: + """Validate consistancy across ranks for the common state dict + + We save the common state dict only on rank 0. We validate to make sure that the common dict is consistant across ranks before saving. + + Args: + common_state_dict: The common state dict present in all ransk + """ + + # Gather the common state dict across ranks onto rank 0 for comparison + rank = torch.distributed.get_rank() + other_rank_state_dicts = [None] * torch.distributed.get_world_size() if rank == 0 else None + torch.distributed.gather_object(common_state_dict, other_rank_state_dicts) + common_state_dict_diff = {} + if rank == 0: + main_rank_state_dict = common_state_dict + for rank, rank_state_dict in enumerate(other_rank_state_dicts[1:], 1): + only_left, only_right, mismatch = diff(main_rank_state_dict, rank_state_dict) + if only_left or only_right or mismatch: + common_state_dict_diff[rank] = (only_left, only_right, mismatch) + + if len(common_state_dict_diff) != 0: + logger.warning( + f'There is difference in the common state dict in different ranks. The differences are {common_state_dict_diff}' + ) + + +def validate_sharding_integrity( + global_metadata: _GlobalMetadata, common_state_dict: CommonStateDict = None +) -> None: + """Validate if the ShardedTensors and ShardedObjects from multiple processes define correct sharding. 
+ + Local ShardedTensors and ShardedObject metadata is exchanged with `torch.distributed.all_gather_object` + and then process with global rank 0 checks if main replicas of the shards: + - cover the whole global tensors + - don't overlap + + Args: + global_metadata (_GlobalMetadata): ShardedTensor and ShardedObject objects from all ranks. + common_state_dict (CommonStateDict): The common state dict stored by rank 0 + + Returns: + None + + Raises: + CheckpointingException for invalid access pattern + """ + + if common_state_dict is not None: + _validate_common_state_dict(common_state_dict) + + if torch.distributed.get_rank() != 0: + return + + key_shardings = defaultdict(list) + for rank, rank_shardings in enumerate(global_metadata): + for sharding in rank_shardings: + key_shardings[sharding.key].append((rank, sharding)) + for key, shardings in key_shardings.items(): + if isinstance(shardings[0][1], ShardedObject): + _validate_objects_for_key(shardings) + else: + _validate_sharding_for_key(shardings) + + +def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]): + some_rank_shard = rank_sharding[0][1] + global_shape = some_rank_shard.global_shape + local_shape = some_rank_shard.local_shape + dtype = some_rank_shard.dtype + has_flattened_range = some_rank_shard.flattened_range is not None + for rank, sharding in rank_sharding: + assert sharding.dtype == dtype, (sharding.dtype, dtype, some_rank_shard) + assert sharding.global_shape == global_shape, ( + sharding.global_shape, + global_shape, + some_rank_shard, + ) + assert sharding.local_shape == local_shape, ( + sharding.local_shape, + local_shape, + some_rank_shard, + ) + assert (sharding.flattened_range is not None) == has_flattened_range, ( + (sharding.flattened_range is not None), + has_flattened_range, + some_rank_shard, + ) + + shard_access_cnt = _compute_shards_access(rank_sharding) + if has_flattened_range: + map_reduce( + rank_sharding, + lambda x: x[1].global_offset, + lambda x: x[1], + _validate_sharding_for_key_flattened, + ) + # For each shard with at least 1 flattened tensor in it, the above + # `_validate_sharding_for_key_flattened` ensure a correct consistent pattern + # The only thing that can go wrong at this point is that some shard don't have + # *any* representatives which will be checked later by comparing `shard_access_cnt == 1` + shard_access_cnt = torch.minimum(shard_access_cnt, torch.tensor([1])) + if not torch.all(shard_access_cnt == 1): + raise CheckpointingException( + f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}' + ) + + +def _compute_shards_access(rank_sharding): + shard_access_cnt = torch.zeros( + rank_sharding[0][1].axis_fragmentations, dtype=torch.int, device='cpu' + ) + for rank, sharding in rank_sharding: + if is_main_replica(sharding.replica_id): + shard_access_cnt[sharding.local_chunk_offset_in_global()] += 1 + return shard_access_cnt + + +def _validate_sharding_for_key_flattened(tensors_by_shard): + all_slices = [] + local_shape = tensors_by_shard[0].local_shape + for sharding in tensors_by_shard: + assert sharding.local_shape == local_shape + sharding: ShardedTensor + if not is_main_replica(sharding.replica_id): + continue + + all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop)) + + starts, stops = map(np.asarray, zip(*sorted(all_slices))) + expected_size = np.product(local_shape) + if starts[0] != 0 or stops[-1] != expected_size or not np.all(starts[1:] == stops[:-1]): + raise CheckpointingException( + f'Flattened ranges dont 
cover the whole shard {tensors_by_shard[0]} of size {expected_size}. Ranges: {(starts, stops)}' + ) + + +def _validate_objects_for_key(sharded_objects: List[ShardedObject]): + """Ensure uniqueness of saved objects.""" + unique_keys = [ + sh_obj.unique_key for _, sh_obj in sharded_objects if is_main_replica(sh_obj.replica_id) + ] + if len(unique_keys) != len(set(unique_keys)): + duplicates = {k: cnt for k, cnt in Counter(unique_keys).items() if cnt > 1} + logger.error(f'Duplicate ShardedObject keys and counts: {duplicates}') + raise CheckpointingException(f'Duplicate ShardedObject keys: {list(duplicates.keys())}') + expected_shard_num = np.prod(sharded_objects[0][1].global_shape) + if len(unique_keys) != expected_shard_num: + err_msg = f'Invalid access pattern: {expected_shard_num - len(unique_keys)} ShardedObject are missing.' + logger.error(f'{err_msg} Existing shards: {unique_keys}') + raise CheckpointingException(err_msg) + + +def determine_global_metadata( + sharded_state_dict: ShardedStateDict, +) -> Tuple[_LocalMetadata, _GlobalMetadata]: + """Exchanges local metadata with `all_gather_object` to determine global metadata. + + Args: + sharded_state_dict (ShardedStateDict): local sharded state dict + + Returns: + Tuple[_LocalMetadata, _GlobalMetadata]: local and global ShardedBase objects with stripped data + """ + local_metadata = [ten.without_data() for ten in nested_values(sharded_state_dict)] + global_metadata = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(global_metadata, local_metadata) + return local_metadata, global_metadata + + +def validate_sharded_objects_handling( + sharded_strategy: Union[SaveShardedStrategy, LoadShardedStrategy], + common_strategy: Union[SaveCommonStrategy, LoadCommonStrategy], +) -> None: + """Checks if either of the passed strategies can handle sharded objects. + + Args: + sharded_strategy (Union[SaveShardedStrategy, LoadShardedStrategy]): sharded strategy used for saving/loading + common_strategy (Union[SaveCommonStrategy, LoadCommonStrategy]): common strategy used for saving/loading + + Returns: + None + + Raises: + CheckpointingException: if both strategies can't handle ShardedObjects + """ + if ( + not sharded_strategy.can_handle_sharded_objects + and not common_strategy.can_handle_sharded_objects + ): + raise CheckpointingException( + f'Either sharded strategy or common strategy must implement ShardedObjects handling.' + f' Both {sharded_strategy} and {common_strategy} specify can_handle_sharded_objects=False' + ) diff --git a/megatron/core/distributed/README.md b/megatron/core/distributed/README.md new file mode 100644 index 0000000000..c4a7528441 --- /dev/null +++ b/megatron/core/distributed/README.md @@ -0,0 +1,11 @@ +## How to use pytorch FSDP2? + +Add these flag to enable Torch FSDP2. + +``` +--use-torch-fsdp2 +--no-gradient-accumulation-fusion +--ckpt-format torch_dist +``` + +It is worth noting that CUDA_MAX_CONNECTIONS=1 should not be enabled to ensure that the communication of FSDP and the computation on the primary stream can be fully parallelized. diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py index 34c7209a27..9dbf83c80d 100644 --- a/megatron/core/distributed/__init__.py +++ b/megatron/core/distributed/__init__.py @@ -1,2 +1,8 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
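The asymmetry between "missing" and "unexpected" keys in the checkpoint-validation helpers above reduces to simple set arithmetic once the key sets are known: unexpected keys need only this rank's requested keys, while missing keys need the union over all ranks. A minimal toy sketch (the key names and sets below are made-up stand-ins for the metadata consumed by `_determine_missing_and_unexpected_keys`):

```python
ckpt_keys = {"embedding.weight", "layer0.weight", "layer1.weight"}        # keys stored in the checkpoint
local_keys = {"layer0.weight", "layer0.extra_bias"}                       # keys this rank asks to load
global_keys = {"embedding.weight", "layer0.weight", "layer0.extra_bias"}  # union of requests over all ranks

unexpected_keys = local_keys - ckpt_keys   # {'layer0.extra_bias'} -> can differ per rank
missing_keys = ckpt_keys - global_keys     # {'layer1.weight'}     -> identical on every rank

print(sorted(missing_keys), sorted(unexpected_keys))
```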
+ +from packaging.version import Version + from .distributed_data_parallel import DistributedDataParallel +from .distributed_data_parallel_config import DistributedDataParallelConfig from .finalize_model_grads import finalize_model_grads +from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel diff --git a/megatron/core/distributed/data_parallel_base.py b/megatron/core/distributed/data_parallel_base.py new file mode 100644 index 0000000000..aed576a7a3 --- /dev/null +++ b/megatron/core/distributed/data_parallel_base.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from contextlib import contextmanager + +import torch + +from ..transformer.module import MegatronModule +from ..transformer.transformer_config import TransformerConfig + + +class _BaseDataParallel(MegatronModule): + """A template class for DistributedDataParallel implementations.""" + + def __init__(self, config: TransformerConfig, module: torch.nn.Module): + super().__init__(config=config) + self.module = module + + def forward(self, *inputs, **kwargs): + """ + Calls the wrapped module's forward() method. + """ + return self.module(*inputs, **kwargs) + + @contextmanager + def no_sync(self): + """ + Context manager that turns off gradient synchronization. + """ + try: + yield + finally: + pass + + def start_grad_sync(self, *unused): + """ + Initiates grad sync (all-reduce or reduce-scatter) communication operations + for all model gradients. + + When overlap_grad_reduce is set to True, dispatches asynchronous communication + calls. When overlap_grad_reduce is set to False, calls synchronous + communication ops. + """ + pass + + def scale_gradients(self, scaling_factor: float) -> None: + """Scale all gradients inside the buffers by `scaling_factor`.""" + pass + + def finish_grad_sync(self): + """ + Finishes grad sync (all-reduce or reduce-scatter) communication operations + for all model gradients. + + When overlap_grad_reduce is set to True, waits for asynchronous communication + calls to complete. When overlap_grad_reduce is set to False, calls synchronous + communication ops. + """ + pass + + def zero_grad_buffer(self): + """ + Zeros out all grad buffers. Needs to be called at the beginning of each + training iteration. + """ + pass + + def broadcast_params(self): + """ + Syncs parameters across all DP ranks. + """ + pass + + def state_dict(self, prefix='', keep_vars=False): + """ + Returns a dictionary containing references to the whole state of the + wrapped module. + + Both parameters and persistent buffers (e.g. running averages) are included. + Keys are corresponding parameter and buffer names. Parameters and buffers + set to None are not included. + """ + return self.module.state_dict(prefix=prefix, keep_vars=keep_vars) + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """ + Returns wrapped module's state_dict for checkpoint saving. + """ + return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars) + + def load_state_dict(self, state_dict, strict=True): + """ + Copies parameters and buffers from state_dict into the wrapped module and its + descendants. If strict is True, then the keys of state_dict must exactly match + the keys returned by this module’s state_dict() function. 
+ """ + self.module.load_state_dict(state_dict, strict=strict) diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py index 4c2c2ee525..ea08db6c12 100644 --- a/megatron/core/distributed/distributed_data_parallel.py +++ b/megatron/core/distributed/distributed_data_parallel.py @@ -1,18 +1,23 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -import math +import logging from contextlib import contextmanager -from typing import Dict import torch from .. import parallel_state -from ..transformer.module import MegatronModule +from ..config_logger import has_config_logger_enabled, log_config_to_disk +from ..transformer.cuda_graphs import is_graph_capturing from ..transformer.transformer_config import TransformerConfig -from .grad_buffer import GradBuffer +from ..utils import is_float8tensor, log_single_rank +from .data_parallel_base import _BaseDataParallel +from .distributed_data_parallel_config import DistributedDataParallelConfig +from .param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets +logger = logging.getLogger(__name__) -class DistributedDataParallel(MegatronModule): + +class DistributedDataParallel(_BaseDataParallel): """ DDP wrapper which stores grads in contiguous buffers. Also has option of overlapping communication with backprop computation by breaking up full model's gradients into smaller @@ -20,17 +25,10 @@ class DistributedDataParallel(MegatronModule): also provides the option to do the gradient accumulation in a type other than the param type (e.g., fp32 for a bf16 model). - Arguments: + Args: config: Transformer config object. + ddp_config: DistributedDataParallel config object. module: Underlying model. - data_parallel_group: Data-parallel process group. - accumulate_allreduce_grads_in_fp32: If true, do the gradient accumulation and - communication in fp32. - overlap_grad_reduce: If true, overlap communication with backprop computation by - breaking up grads into buckets. If false, single synchronous communication call - is used instead. - use_distributed_optimizer: If true, issue reduce-scatter communication calls as part - of distributed optimizer. If false, issue all-reduce communication calls. disable_bucketing: If true, force assign all parameters to a single bucket. If false, use standard bucketing policy: assign parameters to smaller buckets and all-reduce per bucket _if_ overlap_grad_reduce is True and pp_rank is 0. @@ -40,108 +38,240 @@ class DistributedDataParallel(MegatronModule): def __init__( self, config: TransformerConfig, + ddp_config: DistributedDataParallelConfig, module: torch.nn.Module, - data_parallel_group: torch.distributed.ProcessGroup, - accumulate_allreduce_grads_in_fp32: bool, - overlap_grad_reduce: bool, - use_distributed_optimizer: bool, disable_bucketing: bool = False, - bucket_size: int = 40000000, ): - super().__init__(config=config) + super().__init__(config=config, module=module) + if has_config_logger_enabled(config): + log_config_to_disk(config, locals(), prefix=type(self).__name__) + self.module = module + # If bucket_size is not provided as an input, use sane default. + # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL + # ring-reduce implementations are large enough to remain bandwidth-bound rather than + # latency-bound. 
+ if ddp_config.bucket_size is None: + ddp_config.bucket_size = max( + 40000000, 1000000 * parallel_state.get_data_parallel_world_size() + ) # Set bucket_size to infinity if overlap_grad_reduce is False. - self.overlap_grad_reduce = overlap_grad_reduce - self.use_distributed_optimizer = use_distributed_optimizer - - # Turn off bucketing if overlap_grad_reduce is False, if we are on a pipeline stage - # that is not the first (since data-parallel communication on these stages is not on - # the critical path), or if disable_bucketing is True (e.g., we might not want to - # break up model parameters into buckets for model chunks after the first - # in the interleaved schedule). - if not self.overlap_grad_reduce: - bucket_size = None + if not ddp_config.overlap_grad_reduce: + ddp_config.bucket_size = None + + self.ddp_config = ddp_config + log_single_rank( + logger, + logging.INFO, + f'Setting up DistributedDataParallel with config {self.ddp_config}', + ) + + # Turn off bucketing if we are on a pipeline stage that is not the first (since + # data-parallel communication on these stages is not on the critical path), or if + # disable_bucketing is True (e.g., we might not want to break up model parameters + # into buckets for model chunks after the first in the interleaved schedule). + self.bucket_size = self.ddp_config.bucket_size if parallel_state.get_pipeline_model_parallel_rank() > 0: - bucket_size = None + self.bucket_size = None if disable_bucketing: - bucket_size = None - self.bucket_size = bucket_size + self.bucket_size = None - self.module = module - self.grad_buffers = {} - self.expert_grads = [] - self.grad_buffer_param_index_map = {} - self.param_to_grad_buffer = {} + self.param_to_bucket_group = {} # Group parameters by their gradient type. - grad_dtype_to_params = {} - grad_dtype_to_numel = {} param_to_name = {} + dense_params = [] + expert_parallel_params = [] + self.params_with_grad = [] for name, param in self.module.named_parameters(): - if param.requires_grad and getattr(param, 'allreduce', True): - param.grad_added_to_main_grad = False - param_to_name[param] = name - dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype - - params = grad_dtype_to_params.get(dtype, []) + if not param.requires_grad: + continue + + # Track params with grad to enable direct setting + # of param.grad_added_to_main_grad + self.params_with_grad.append(param) + + param.grad_added_to_main_grad = False + param_to_name[param] = name + + if getattr(param, 'allreduce', True): + dense_params.append(param) + else: + expert_parallel_params.append(param) + + def _allocate_buffers_for_parameters( + input_params, data_parallel_group, gradient_scaling_factor + ): + param_and_grad_dtype_to_params = {} + param_and_grad_dtype_to_offsets = {} + param_and_grad_dtype_to_indices = {} + + # Group parameters by their gradient type. + for param in input_params: + assert param.requires_grad + + param_dtype = param.dtype + if is_float8tensor(param): + # Currently TE's Float8Tensor is a wrapper of torch.Tensor. It has a "fake" + # dtype (usually a higher precision dtype such as bfloat16), but its actual + # data is stored in the form of a torch uint8 tensor within the Float8Tensor's + # ".data" attribute. Therefore, when creating the param buffer for fp8 params, + # it is necessary to use torch.uint8, not the "fake" dtype got from + # "param.dtype". 
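Because the `_BaseDataParallel` template introduced earlier in this change only defines pass-through behavior, concrete wrappers such as this `DistributedDataParallel` override just the pieces they need. A hypothetical minimal subclass for illustration (`NoopDataParallel` is not part of this change):

```python
from megatron.core.distributed.data_parallel_base import _BaseDataParallel


class NoopDataParallel(_BaseDataParallel):
    """Hypothetical wrapper: inherits the pass-through forward()/state_dict()
    from _BaseDataParallel and only customizes gradient zeroing."""

    def zero_grad_buffer(self):
        # No contiguous grad buffers in this toy wrapper; just drop .grad references.
        for param in self.module.parameters():
            param.grad = None
```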
+ param_dtype = torch.uint8 + grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else param.dtype + + params = param_and_grad_dtype_to_params.get((param_dtype, grad_dtype), []) params.append(param) - grad_dtype_to_params[dtype] = params + param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params + + # Get the index of each param among the params with same dtype, if a param is fp8, + # use its "fake" high precision dtype to find which params have same dtype with it. + # For example: + # Case 1: + # params = [p1(bf16), p2(bf16), p3(bf16), p4(bf16)] + # param_and_grad_dtype_to_indices = { + # (torch.bfloat16, torch.float32): [0, 1, 2, 3], + # } + # Case 2: + # params = [p1(bf16), p2(fp8), p3(fp8), p4(bf16)] + # param_and_grad_dtype_to_indices = { + # (torch.bfloat16, torch.float32): [0, 3], + # (torch.uint8, torch.float32): [1, 2], + # } + # We need these indices to load a non-native-fp8 checkpoint in native-fp8 mode. + offset = param_and_grad_dtype_to_offsets.get((param.dtype, grad_dtype), 0) + param_and_grad_dtype_to_offsets[(param.dtype, grad_dtype)] = offset + 1 + indices = param_and_grad_dtype_to_indices.get((param_dtype, grad_dtype), []) + indices.append(offset) + param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)] = indices + + if not config.calculate_per_token_loss: + target_gradient_scaling_factor = 1.0 / parallel_state.get_data_parallel_world_size( + with_context_parallel=True + ) + if self.ddp_config.average_in_collective: + if self.ddp_config.num_distributed_optimizer_instances == 1: + # Collective is averaging gradients in collective with data_parallel_group. + assert ( + gradient_scaling_factor + / torch.distributed.get_world_size(group=data_parallel_group) + == target_gradient_scaling_factor + ) + else: + # For non-expert parameters, gradient_scaling_factor is 1. + # For expert parameters, gradient_scaling_factor is 1/ep_size. + assert (gradient_scaling_factor == 1) or ( + gradient_scaling_factor + == (1.0 / parallel_state.get_expert_model_parallel_world_size()) + ) + else: + assert gradient_scaling_factor == target_gradient_scaling_factor + + # Allocate the grad buffers and map the grads. + buffers = [] + for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items(): + buffers.append( + _ParamAndGradBuffer( + self.ddp_config, + param_dtype, + grad_dtype, + params, + data_parallel_group, + self.bucket_size, + param_to_name, + gradient_scaling_factor, + param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)], + ) + ) - # Calculate number of elements per dtype. - grad_dtype_to_numel[dtype] = ( - grad_dtype_to_numel.get(dtype, 0) + param.data.nelement() + # In some scenarios, we want to put buckets from different buffers into a group so that + # their communication can be aggregated. For example, when there are both fp8 buffers + # and bf16 buffers in the model and vpp is enabled, each model chunk will have an fp8 + # bucket and a bf16 bucket, which doubles the number of communication kernels, and + # because of the use of CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back + # communications will prevent the overlap of the communication kernels with computation + # kernels. + # If bucketing is explicitly disabled, then put all buckets in a buffer into a single + # bucket group. 
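The (param dtype, grad dtype) grouping and the per-dtype indices described in the Case 1 / Case 2 comment above can be reproduced with a few dictionaries. A self-contained sketch (`FakeParam` and the `is_fp8` flag are stand-ins for real parameters and `is_float8tensor`):

```python
from collections import namedtuple

import torch

FakeParam = namedtuple("FakeParam", ["name", "dtype", "is_fp8"])  # stand-in for nn.Parameter

params = [
    FakeParam("p1", torch.bfloat16, False),
    FakeParam("p2", torch.bfloat16, True),   # fp8 param with a bf16 "fake" dtype
    FakeParam("p3", torch.bfloat16, True),
    FakeParam("p4", torch.bfloat16, False),
]

grad_dtype = torch.float32  # e.g. grad_reduce_in_fp32=True
groups, offsets, indices = {}, {}, {}
for p in params:
    param_dtype = torch.uint8 if p.is_fp8 else p.dtype  # fp8 data is stored as uint8
    groups.setdefault((param_dtype, grad_dtype), []).append(p.name)
    # The index is counted among params sharing the *fake* high-precision dtype.
    offset = offsets.get((p.dtype, grad_dtype), 0)
    offsets[(p.dtype, grad_dtype)] = offset + 1
    indices.setdefault((param_dtype, grad_dtype), []).append(offset)

print(groups)   # {(bf16, fp32): ['p1', 'p4'], (uint8, fp32): ['p2', 'p3']}
print(indices)  # {(bf16, fp32): [0, 3],       (uint8, fp32): [1, 2]}
```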
+ bucket_groups = partition_buckets(buffers, force_single_bucket_group=disable_bucketing) + + if self.ddp_config.num_distributed_optimizer_instances > 1: + assert ( + self.ddp_config.use_distributed_optimizer + ), 'Partial DistOpt cannot be used without DistOpt' + communication_stream = torch.cuda.Stream(device=torch.cuda.current_device()) + for bucket_group in bucket_groups: + bucket_group.inter_distributed_optimizer_instance_group = ( + parallel_state.get_inter_partial_data_parallel_group() + ) + bucket_group.communication_stream = communication_stream + + # Set `next_param_gather_bucket_group` for different bucket groups by iterating through + # buckets in reverse order (since all-gathers happen in reverse order of buckets). + if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather: + num_bucket_groups = len(bucket_groups) + for i in range(1, num_bucket_groups): + bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = ( + bucket_groups[num_bucket_groups - i - 1] + ) + + # Create map from param to bucket group, used in pre_hook. + for bucket_group in bucket_groups: + for bucket in bucket_group.buckets: + for param in bucket.params_list: + self.param_to_bucket_group[param] = bucket_group + + return buffers, bucket_groups + + if config.calculate_per_token_loss: + gradient_scaling_factor = 1.0 + expert_gradient_scaling_factor = 1.0 + else: + if self.ddp_config.average_in_collective: + gradient_scaling_factor = 1.0 + expert_gradient_scaling_factor = ( + 1.0 / parallel_state.get_expert_model_parallel_world_size() + ) + else: + data_parallel_world_size = parallel_state.get_data_parallel_world_size( + with_context_parallel=True ) - # Allocate the grad buffers and map the grads. - # The grad buffer under the hood creates buckets as appropriate based on bucket_size. - data_parallel_world_size = torch.distributed.get_world_size(group=data_parallel_group) - for dtype, params in grad_dtype_to_params.items(): - # Pad so size is divisible by the data parallel size. - numel = grad_dtype_to_numel[dtype] - numel_padded = ( - int(math.ceil(numel / data_parallel_world_size)) * data_parallel_world_size + gradient_scaling_factor = 1.0 / data_parallel_world_size + expert_gradient_scaling_factor = 1.0 / data_parallel_world_size + + # Allocate the param+grad buffers for dense params' grads. + self.buffers, self.bucket_groups = _allocate_buffers_for_parameters( + dense_params, + parallel_state.get_data_parallel_group( + with_context_parallel=True, partial_data_parallel=True + ), + gradient_scaling_factor=gradient_scaling_factor, + ) + + # Allocate separate param+grad buffers for expert parallel params' grads. + self.expert_parallel_buffers, self.expert_parallel_bucket_groups = ( + _allocate_buffers_for_parameters( + expert_parallel_params, + parallel_state.get_expert_data_parallel_group(), + gradient_scaling_factor=expert_gradient_scaling_factor, ) + ) - self.grad_buffers[dtype] = GradBuffer( - numel, - numel_padded, - dtype, - params, - data_parallel_group, - bucket_size, - param_to_name, - self.overlap_grad_reduce, - self.use_distributed_optimizer, - ) + # Delete references to weight_tensor if they exist since we don't want two parameter copies + # if we re-mapped parameters (which happens when we use the distributed optimizer). + # This is a temporary workaround around a TE bug that is fixed with + # https://github.com/NVIDIA/TransformerEngine/pull/719. 
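The gradient scaling factors selected above follow a small decision tree. A sketch of that selection as a standalone function (a hypothetical helper that mirrors the branches in `__init__`, returning the dense and expert factors):

```python
def pick_gradient_scaling_factors(calculate_per_token_loss: bool,
                                  average_in_collective: bool,
                                  dp_size: int, ep_size: int):
    """Returns (dense_factor, expert_factor), mirroring the branches above."""
    if calculate_per_token_loss:
        # Scaling is deferred to finalize_model_grads(), which divides by num_tokens.
        return 1.0, 1.0
    if average_in_collective:
        # The dense collective averages directly; expert grads still need 1/ep_size.
        return 1.0, 1.0 / ep_size
    return 1.0 / dp_size, 1.0 / dp_size

print(pick_gradient_scaling_factors(False, False, dp_size=64, ep_size=8))  # (0.015625, 0.015625)
print(pick_gradient_scaling_factors(False, True,  dp_size=64, ep_size=8))  # (1.0, 0.125)
```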
+ if self.ddp_config.use_distributed_optimizer: - # Parameters are laid out in the corresponding grad_buffer in reverse - # order, so count indices from the back. - index = grad_dtype_to_numel[dtype] - for param in params: - self.param_to_grad_buffer[param] = self.grad_buffers[dtype] - if dtype not in self.grad_buffer_param_index_map: - self.grad_buffer_param_index_map[dtype] = {} - - index -= param.data.nelement() - # Store the indices / bucket of each param. - self.grad_buffer_param_index_map[dtype][param] = ( - index, - index + param.data.nelement(), - self.grad_buffers[dtype].param_to_bucket_index[param], - ) + @torch.no_grad() + def unmap_weight_tensor(m): + if hasattr(m, 'weight_tensor'): + m.weight_tensor = None - # Allocate discreate buffer for MoE params' grads - for param in self.module.parameters(): - if param.requires_grad and not getattr(param, 'allreduce', True): - dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype - param.main_grad = torch.zeros( - param.data.shape, - dtype=dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - self.expert_grads.append(param.main_grad) + self.module.apply(unmap_weight_tensor) # Register backward hook. # Accumulation function for the gradients need to be stored so they @@ -153,48 +283,144 @@ def __init__( param_tmp = param.expand_as(param) # Get the gradient accumulator function. grad_acc = param_tmp.grad_fn.next_functions[0][0] - grad_acc.register_hook(self._make_param_hook(param, self.param_to_grad_buffer)) + grad_acc.register_hook(self._make_backward_post_hook(param)) self.grad_accs.append(grad_acc) - def forward(self, *inputs, **kwargs): + self.use_forward_hook = ( + self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather + ) + self.remove_forward_pre_hook_handles = {} + if self.use_forward_hook: + self.enable_forward_pre_hook() + self.overlap_param_gather_with_optimizer_step = False + + def enable_forward_pre_hook(self): """ - Calls the wrapped module's forward() method. + Enable forward pre-hooks needed for param all-gather overlap with forward compute. """ - return self.module(*inputs, **kwargs) + assert self.use_forward_hook + assert len(self.remove_forward_pre_hook_handles) == 0 + # Register forward pre-hook for all sub-modules. + for module in self.module.modules(): + self.remove_forward_pre_hook_handles[module] = module.register_forward_pre_hook( + self._make_forward_pre_hook() + ) - def _make_param_hook( - self, param: torch.nn.Parameter, param_to_grad_buffer: Dict[torch.nn.Parameter, GradBuffer] - ): + def disable_forward_pre_hook(self, param_sync: bool = True): """ - Creates the all-reduce / reduce-scatter hook for backprop. + Disable forward pre-hooks needed for param all-gather overlap with forward compute. + Skip synchronous param all-gather if `param_sync` is False. + """ + assert self.use_forward_hook + # De-register forward pre-hook for all sub-modules. + for module in self.module.modules(): + assert self.remove_forward_pre_hook_handles[module] is not None + self.remove_forward_pre_hook_handles[module].remove() + del self.remove_forward_pre_hook_handles[module] + assert len(self.remove_forward_pre_hook_handles) == 0 + + # Force synchronize parameters. + if param_sync: + self.start_param_sync(force_sync=True) + + def _make_forward_pre_hook(self): + """ + Create a forward pre-hook to wait on all-gather handles when necessary (i.e., + when a module uses a parameter in a bucket with a still incomplete all-gather). 
""" - def param_hook(*unused): - if param.requires_grad: - if self.overlap_grad_reduce: + def hook(module, *unused): + assert ( + self.use_forward_hook + ), "Should use pre-hook only when overlap_param_gather is True" + + if is_graph_capturing(): + return + + # Make sure all parameters in this module have been all-gathered as necessary. + for param in module.parameters(recurse=False): + # Skip parameters without an associated buffer (such parameters have a + # .requires_grad field equal to False). + if param not in self.param_to_bucket_group: + continue + assert param.requires_grad + + # If aligning param all-gather across pipeline stages, all-gather is dispatched + # by start_param_sync calls in core/pipeline_parallelism/schedules.py. + # If overlapping param all-gather with optimizer step, then all-gather has + # already been dispatched in optimizer step. + skip_next_bucket_dispatch = ( + self.ddp_config.align_param_gather + or self.overlap_param_gather_with_optimizer_step + ) + self.param_to_bucket_group[param].finish_param_sync( + skip_next_bucket_dispatch=skip_next_bucket_dispatch + ) + + return hook + + def _make_backward_post_hook(self, param: torch.nn.Parameter): + """ + Creates a backward post-hook to dispatch an all-reduce / reduce-scatter when + ready (i.e., when all grads in a bucket have been computed in all microbatches + in a batch). + """ + + def hook(*unused): + if is_graph_capturing(): + return + + if param in self.param_to_bucket_group: + assert param.requires_grad + if self.ddp_config.overlap_grad_reduce: assert ( param.grad is not None ), 'param.grad being None is not safe when overlap_grad_reduce is True' - if param.grad is not None and not param.grad_added_to_main_grad: + if param.grad is not None and ( + not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False) + ): param.main_grad.add_(param.grad.data) param.grad = None - if self.overlap_grad_reduce: - param_to_grad_buffer[param].register_grad_ready(param) - return param_hook + if self.ddp_config.overlap_grad_reduce: + self.param_to_bucket_group[param].register_grad_ready(param) + + return hook @contextmanager def no_sync(self): """ Context manager that turns off gradient synchronization. """ - for grad_buffer in self.grad_buffers.values(): - grad_buffer.is_last_microbatch = False + for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: + bucket_group.is_last_microbatch = False try: yield finally: - for grad_buffer in self.grad_buffers.values(): - grad_buffer.is_last_microbatch = True + for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: + bucket_group.is_last_microbatch = True + + def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bool = False): + """ + Initiates param sync (all-gather) communication operations for all model parameters. + + By default, when overlap_param_gather is set to True, dispatches asynchronous communication + calls; when overlap_param_gather is set to False, calls synchronous communication + ops. Can override this default behavior using flags below. + + Args: + force_sync (bool, optional): force synchronous collective regardless of + other settings. + force_dispatch (bool, optional): force dispatch regardless of other settings. + """ + if not force_sync: + # If overlapping param AG with optimizer step, AG should not be dispatched again + # in forward_backward_step. 
+ if self.overlap_param_gather_with_optimizer_step and not force_dispatch: + return + + for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: + bucket_group.start_param_sync(force_sync=force_sync) def start_grad_sync(self, *unused): """ @@ -205,8 +431,8 @@ def start_grad_sync(self, *unused): calls. When overlap_grad_reduce is set to False, calls synchronous communication ops. """ - for grad_buffer in self.grad_buffers.values(): - grad_buffer.start_grad_sync() + for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: + bucket_group.start_grad_sync() def finish_grad_sync(self): """ @@ -217,54 +443,41 @@ def finish_grad_sync(self): calls to complete. When overlap_grad_reduce is set to False, calls synchronous communication ops. """ - for grad_buffer in self.grad_buffers.values(): - grad_buffer.finish_grad_sync() + for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: + bucket_group.finish_grad_sync() + + def scale_gradients(self, scaling_factor: float): + """Scale all gradients inside the buffers by `scaling_factor`.""" + for buffer in self.buffers + self.expert_parallel_buffers: + buffer.scale_gradients(scaling_factor) def zero_grad_buffer(self): """ Zeros out all grad buffers. Needs to be called at the beginning of each training iteration. """ - for param in self.module.parameters(): - if param.requires_grad: - param.grad_added_to_main_grad = False - for grad_buffer in self.grad_buffers.values(): - grad_buffer.reset() - for expert_grad in self.expert_grads: - expert_grad.zero_() + for param in self.params_with_grad: + param.grad_added_to_main_grad = False + for buffer in self.buffers + self.expert_parallel_buffers: + buffer.reset() + for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups: + bucket_group.reset() def broadcast_params(self): """ Syncs parameters across all DP ranks. """ for param in self.module.parameters(): + is_expert_parallel = not getattr(param, 'allreduce', True) + + if is_expert_parallel: + data_parallel_group = parallel_state.get_expert_data_parallel_group() + else: + data_parallel_group = parallel_state.get_data_parallel_group( + with_context_parallel=True, partial_data_parallel=True + ) torch.distributed.broadcast( param.data, - src=parallel_state.get_data_parallel_src_rank(with_context_parallel=True), - group=parallel_state.get_data_parallel_group(with_context_parallel=True), + src=torch.distributed.get_global_rank(data_parallel_group, 0), + group=data_parallel_group, ) - - def state_dict(self, prefix='', keep_vars=False): - """ - Returns a dictionary containing references to the whole state of the - wrapped module. - - Both parameters and persistent buffers (e.g. running averages) are included. - Keys are corresponding parameter and buffer names. Parameters and buffers - set to None are not included. - """ - return self.module.state_dict(prefix=prefix, keep_vars=keep_vars) - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """ - Returns wrapped module's state_dict for checkpoint saving. - """ - return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars) - - def load_state_dict(self, state_dict, strict=True): - """ - Copies parameters and buffers from state_dict into the wrapped module and its - descendants. If strict is True, then the keys of state_dict must exactly match - the keys returned by this module’s state_dict() function. 
- """ - self.module.load_state_dict(state_dict, strict=strict) diff --git a/megatron/core/distributed/distributed_data_parallel_config.py b/megatron/core/distributed/distributed_data_parallel_config.py new file mode 100644 index 0000000000..9e9058d4be --- /dev/null +++ b/megatron/core/distributed/distributed_data_parallel_config.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class DistributedDataParallelConfig: + """Configuration for DistributedDataParallel.""" + + grad_reduce_in_fp32: bool = False + """If true, reduce grads in fp32.""" + + overlap_grad_reduce: bool = False + """If true, overlap grad all-reduce / reduce-scatter with backward compute.""" + + overlap_param_gather: bool = False + """If true, overlap param all-gather with forward compute.""" + + align_param_gather: bool = False + """If true, all PP stages will launch param all-gathers simultaneously. Otherwise, each + PP stage will independently launch as needed. + """ + + use_distributed_optimizer: bool = False + """If true, issue reduce-scatter collectives to aggregate gradients and clean up + originally allocated model parameters, otherwise issue all-reduce collectives. + """ + + num_distributed_optimizer_instances: int = 1 + """Sets the factor by which the DP domain is sharded to have the partial DistOpt + enabled. Defaults to 1, which means DistOpt is across entire DP domain. + """ + + check_for_nan_in_grad: bool = False + """If true, check for NaNs and Infs in gradients _before_ communication collective.""" + + check_for_large_grads: bool = False + """If true, check for unexpectedly large gradients _before_ communication collective.""" + + bucket_size: Optional[int] = None + """Maximum number of parameters in each bucket. If unspecified, MCore uses a default + value of max(40000000, 1000000 * dp_size) parameters (larger DP sizes need larger + buckets to ensure collectives do not become latency-bound).""" + + pad_buckets_for_high_nccl_busbw: bool = False + """If true, make sure the bucket size is divisible by a large power of 2 (2^16) to + ensure NCCL collectives have high bus bandwidth at large DP counts, since NCCL + message size (which for ring algorithms is bucket_size / dp_size) apparently needs + to be divisible by a power of 2 for high busbw.""" + + average_in_collective: bool = False + """If true, compute average in collective directly, as opposed to dividing by the + dp_size first and then computing sum in the collective.""" + + fp8_param_gather: bool = False + """If true, keep the compute param in fp8 (do not use any other intermediate dtype) and + perform the param all-gather in fp8.""" diff --git a/megatron/core/distributed/finalize_model_grads.py b/megatron/core/distributed/finalize_model_grads.py index 916e4f3ecb..4b2b2bb359 100644 --- a/megatron/core/distributed/finalize_model_grads.py +++ b/megatron/core/distributed/finalize_model_grads.py @@ -1,62 +1,166 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -from typing import List +from typing import List, Optional, Union import torch from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +try: + from torch.distributed._tensor import DTensor, distribute_tensor + + HAVE_DTENSOR = True +except ImportError: + HAVE_DTENSOR = False + from .. 
import parallel_state +from ..transformer.moe.moe_utils import get_updated_expert_bias from ..transformer.transformer_config import TransformerConfig from ..utils import get_attr_wrapped_model, get_model_config +def _unshard_if_dtensor(tensor: Union[torch.Tensor, "DTensor"]) -> torch.Tensor: + """ + Unshards the input tensor if it is a DTensor and otherwise returns the + tensor unmodified. + + Args: + tensor (Union[torch.Tensor, DTensor]): The tensor to potentially unshard. + + Returns: + An unsharded version of the input tensor if it is a DTensor, or the + input tensor unmodified if it is not a DTensor. + """ + if HAVE_DTENSOR and isinstance(tensor, DTensor): + unsharded_tensor = tensor.full_tensor() + for k, v in vars(tensor).items(): + setattr(unsharded_tensor, k, v) + return unsharded_tensor + return tensor + + +def _reshard_if_dtensor( + tensor_to_shard: torch.Tensor, reference_tensor: Union[torch.Tensor, "DTensor"] +) -> Union[torch.Tensor, "DTensor"]: + """ + Reshards the input tensor to match the sharding configuration of the + reference tensor if the reference tensor is a DTensor. Otherwise, returns + the reference tensor unmodified. + + Args: + tensor_to_shard (torch.Tensor): The tensor to be potentially sharded. + reference_tensor (Union[torch.Tensor, DTensor]): The reference tensor + for the sharding configuration. + + Returns: + Union[torch.Tensor, DTensor]: The sharded tensor matching the reference tensor's + configuration, or the reference tensor itself if it is not a DTensor. + """ + if HAVE_DTENSOR and isinstance(reference_tensor, DTensor): + sharded_tensor = distribute_tensor( + tensor_to_shard, + device_mesh=reference_tensor.device_mesh, + placements=reference_tensor.placements, + ) + for k, v in vars(reference_tensor).items(): + setattr(sharded_tensor, k, v) + return sharded_tensor + return reference_tensor + + +def _allreduce_conditional_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): + """ + All-reduce conditional embedding grads. + + Reduce grads across all the pp stages to ensure that parameters of the conditional embedders + (e.g., timestep embedder, FPS embedder, label embedder) stay in sync. + This is for the models with replicated embedders on each PP / VPP rank, like diffusion models. + """ + + if parallel_state.get_pipeline_model_parallel_world_size() > 1 and getattr( + config, "has_cond_embedder", False + ): + grads_dict = {} + for model_chunk in model: + for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')(): + if param.requires_grad and getattr(param, 'pipeline_parallel', False): + grad = param.main_grad + if name in grads_dict: + # Add all the virtual PP rank's gradients to + # the first local virtual PP rank. + grads_dict[name][0].add_(grad) + # Append to the end for later update after cross-rank reduce. + grads_dict[name].append(grad) + else: + grads_dict[name] = [grad] + if grads_dict: + # All-reduce the gradient on the first VPP rank. + grads = [param_grad[0] for _, param_grad in grads_dict.items()] + coalesced = _flatten_dense_tensors(grads) + torch.distributed.all_reduce( + coalesced, group=parallel_state.get_pipeline_model_parallel_group() + ) + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) + + # Update the gradients on other VPP ranks. + for grads in grads_dict.values(): + for grad in grads[1:]: + grad.copy_(grads[0]) + + def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): """ All-reduce word embedding grads. 
Reduce grads across first and last stages to ensure that word_embeddings parameters stay in - sync. This should only run for models that support pipelined model parallelism (BERT and GPT). + sync. """ if ( parallel_state.is_rank_in_embedding_group(ignore_virtual=True) - and parallel_state.get_pipeline_model_parallel_world_size() > 1 + and torch.distributed.get_world_size(parallel_state.get_embedding_group()) > 1 ): if parallel_state.is_pipeline_first_stage(ignore_virtual=True): model_module = model[0] elif parallel_state.is_pipeline_last_stage(ignore_virtual=True): model_module = model[-1] - else: # We do not support the interleaved schedule for T5 yet. + else: # We do not support an interleaved schedule for models with encoders yet. model_module = model[0] - # Look for module with 'pre_process' attribute to get around the fact that DDP and - # other wrapper classes inherit from non-core MegatronModule that has - # 'share_embeddings_and_output_weights' and 'shared_embedding_or_output_weight' - # attributes already, causing get_attr_wrapped_model() to not unwrap anything here. - # TODO: Clean this up once the wrapper classes inherit from core MegatronModule. model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True) if model_module.share_embeddings_and_output_weights: weight = model_module.shared_embedding_or_output_weight() - grad = weight.main_grad + grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad" + orig_grad = getattr(weight, grad_attr) + grad = _unshard_if_dtensor(orig_grad) torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group()) + setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad)) def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): """ - All-reduce position_embeddings grad across first (encoder) and split (decoder) stages to - ensure that position embeddings parameters stay in sync. This should only run for T5 models - with pipeline parallelism. + All-reduce position_embeddings grad across encoder and decoder stages to ensure that position + embeddings parameters stay in sync. """ if ( parallel_state.is_rank_in_position_embedding_group() - and parallel_state.get_pipeline_model_parallel_world_size() > 1 - and config.pipeline_model_parallel_split_rank is not None + and torch.distributed.get_world_size(parallel_state.get_position_embedding_group()) > 1 ): - model_module = model[0] - grad = get_attr_wrapped_model( - model_module, 'language_model.embedding.position_embeddings.weight.main_grad' - ) + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + model_module = model[0] + elif parallel_state.is_pipeline_last_stage(ignore_virtual=True): + model_module = model[-1] + else: # We do not support an interleaved schedule for models with encoders yet. 
+ model_module = model[0] + + model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True) + assert hasattr(model_module, 'position_embeddings') + weight = model_module.position_embeddings.weight + grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad" + orig_grad = getattr(weight, grad_attr) + grad = _unshard_if_dtensor(orig_grad) torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group()) + setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad)) def _allreduce_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): @@ -74,50 +178,70 @@ def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: Transformer # All-reduce layernorm parameters across model parallel nodes # when sequence parallelism is used - if parallel_state.get_tensor_model_parallel_world_size() > 1 and config.sequence_parallel: + if parallel_state.get_tensor_model_parallel_world_size() > 1 and ( + config.sequence_parallel or config.qk_layernorm + ): + params = [] grads = [] for model_chunk in model: - for param in get_attr_wrapped_model(model_chunk, 'parameters')(): - if getattr(param, 'sequence_parallel', False): - grad = param.main_grad + for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')(): + if param.requires_grad and ( + getattr(param, 'sequence_parallel', False) + or 'q_layernorm' in name + or 'k_layernorm' in name + ): + params.append(param) + grad_attr = "main_grad" if hasattr(param, "main_grad") else "grad" + grad = getattr(param, grad_attr) + grad = _unshard_if_dtensor(grad) grads.append(grad.data) - coalesced = _flatten_dense_tensors(grads) - torch.distributed.all_reduce( - coalesced, group=parallel_state.get_tensor_model_parallel_group() - ) - for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): - buf.copy_(synced) + if grads: + coalesced = _flatten_dense_tensors(grads) + torch.distributed.all_reduce( + coalesced, group=parallel_state.get_tensor_model_parallel_group() + ) + for param, buf, synced in zip( + params, grads, _unflatten_dense_tensors(coalesced, grads) + ): + buf.copy_(synced) + grad_attr = "main_grad" if hasattr(param, "main_grad") else "grad" + orig_grad = getattr(param, grad_attr) + setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad)) -def _allreduce_expert_grads(model: List[torch.nn.Module], config: TransformerConfig): +def _update_router_expert_bias(model: List[torch.nn.Module], config: TransformerConfig): """ - All-reduce expert grads (for expert parallelism). + Update the expert bias of the router for a global batch. + This requires all-reduce of local_tokens_per_expert across TPxCPxDP ranks """ + tokens_per_expert_list = [] + expert_bias_list = [] + for model_chunk in model: + for module in get_attr_wrapped_model(model_chunk, 'modules')(): + if hasattr(module, 'expert_bias'): + tokens_per_expert_list.append(module.local_tokens_per_expert) + expert_bias_list.append(module.expert_bias) + # For hybrid models with both MoE and Dense layers, this list can be empty. 
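The coalesced all-reduce used in the layernorm/embedding grad reductions above (flatten many small grads into one tensor, reduce once, then copy the result back) can be shown in isolation. A single-process sketch that replaces the `torch.distributed.all_reduce` call with a placeholder so it runs without a process group:

```python
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# A few small "gradients" standing in for sequence-parallel layernorm grads.
grads = [torch.ones(3), torch.full((2, 2), 2.0), torch.arange(4.0)]

coalesced = _flatten_dense_tensors(grads)  # one contiguous 1-D tensor
# torch.distributed.all_reduce(coalesced, group=...)   # <- single collective in the real code
coalesced *= 2.0                           # placeholder for the reduced result

for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)                      # write the reduced values back in place

print(grads[1])  # tensor([[4., 4.], [4., 4.]])
```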
+ if len(expert_bias_list) == 0: + return + stacked_tokens_per_expert = torch.stack(tokens_per_expert_list, dim=0) + stacked_expert_bias = torch.stack(expert_bias_list, dim=0) + stacked_updated_expert_bias = get_updated_expert_bias( + stacked_tokens_per_expert, stacked_expert_bias, config.moe_router_bias_update_rate + ) - # All-reduce switchmlp parameters across data modulo expert parallel nodes - if ( - config.expert_model_parallel_size > 1 - and config.expert_model_parallel_size < parallel_state.get_data_parallel_world_size() + for tokens_per_expert, expert_bias, updated_expert_bias in zip( + tokens_per_expert_list, expert_bias_list, stacked_updated_expert_bias ): - grads = [] - for model_chunk in model: - for param in get_attr_wrapped_model(model_chunk, 'parameters')(): - if not getattr(param, 'allreduce', True): - grad = param.main_grad - grads.append(grad.data) - coalesced = _flatten_dense_tensors(grads) - torch.distributed.all_reduce( - coalesced, group=parallel_state.get_data_modulo_expert_parallel_group() - ) - for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): - buf.copy_(synced) + tokens_per_expert.zero_() + expert_bias.copy_(updated_expert_bias) -def finalize_model_grads(model: List[torch.nn.Module]): +def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None): """ All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism, - embedding grads across first and last pipeline stages (if not tied), and expert grads - for expert parallelism. + embedding grads across first and last pipeline stages (if not tied), + scale gradients by `num_tokens`. """ config = get_model_config(model[0]) @@ -130,6 +254,15 @@ def finalize_model_grads(model: List[torch.nn.Module]): if config.timers is not None: config.timers('all-grads-sync').stop() + # All-reduce t_embedder grads (for pp & vpp of DiT). + if config.timers is not None: + config.timers('conditional-embedder-grads-all-reduce', log_level=1).start( + barrier=config.barrier_with_L1_time + ) + _allreduce_conditional_embedding_grads(model, config) + if config.timers is not None: + config.timers('conditional-embedder-grads-all-reduce').stop() + # All-reduce layer-norm grads (for sequence parallelism). if config.timers is not None: config.timers('layernorm-grads-all-reduce', log_level=1).start( @@ -148,11 +281,35 @@ def finalize_model_grads(model: List[torch.nn.Module]): if config.timers is not None: config.timers('embedding-grads-all-reduce').stop() - # All-reduce expert grads (for expert parallelism). - if config.timers is not None: - config.timers('expert-grads-all-reduce', log_level=1).start( - barrier=config.barrier_with_L1_time - ) - _allreduce_expert_grads(model, config) - if config.timers is not None: - config.timers('expert-grads-all-reduce').stop() + if config.moe_router_enable_expert_bias: + _update_router_expert_bias(model, config) + + # normalize gradients for per-token loss normalization. + # if we are using by the number of tokens, then we use that as a divisor. this number + # will be the total number of non-padded tokens in the global batch. + if num_tokens is not None: + + # the number of tokens is only present on the last stage, so broadcast it + # to the other ranks in the pipeline parallel group. 
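The router expert-bias update above is a gather/compute/scatter-back choreography per MoE layer: stack the per-layer counters and biases, compute updated biases for the whole batch, then copy them back and reset the counters. A toy sketch with a stand-in update rule (`fake_update_bias` is not the real `get_updated_expert_bias`; it only illustrates the bookkeeping around it):

```python
import torch

# Per-layer state collected from modules that expose `expert_bias`.
tokens_per_expert_list = [torch.tensor([10., 2., 0., 4.]), torch.tensor([3., 3., 3., 3.])]
expert_bias_list = [torch.zeros(4), torch.zeros(4)]


def fake_update_bias(stacked_tokens, stacked_bias, update_rate=1e-3):
    # Stand-in rule: nudge bias up for under-used experts, down for over-used ones.
    mean_load = stacked_tokens.mean(dim=-1, keepdim=True)
    return stacked_bias + update_rate * torch.sign(mean_load - stacked_tokens)


stacked_tokens = torch.stack(tokens_per_expert_list, dim=0)
stacked_bias = torch.stack(expert_bias_list, dim=0)
updated = fake_update_bias(stacked_tokens, stacked_bias)

for tokens, bias, new_bias in zip(tokens_per_expert_list, expert_bias_list, updated):
    tokens.zero_()        # reset the per-batch token counters
    bias.copy_(new_bias)  # install the updated bias in place
```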
+ last_rank = parallel_state.get_pipeline_model_parallel_last_rank() + pp_group = parallel_state.get_pipeline_model_parallel_group() + + if not isinstance(last_rank, list): + assert not isinstance(last_rank, list) + last_rank = [last_rank] + assert not isinstance(pp_group, list) + pp_group = [pp_group] + + # need to do a broadcast for every pp group, even though num_tokens should be the same. + num_tokens_list = [] + for lr, group in zip(last_rank, pp_group): + torch.distributed.broadcast(num_tokens, src=lr, group=group) + num_tokens_list.append(torch.clone(num_tokens)) + assert all(x.item() == num_tokens_list[0] for x in num_tokens_list) + + # all-reduce across DP ranks. + torch.distributed.all_reduce(num_tokens, group=parallel_state.get_data_parallel_group()) + for model_chunk in model: + if num_tokens > 0: + scaling = 1.0 / num_tokens + model_chunk.scale_gradients(scaling) diff --git a/megatron/core/distributed/grad_buffer.py b/megatron/core/distributed/grad_buffer.py deleted file mode 100644 index 223c2bef18..0000000000 --- a/megatron/core/distributed/grad_buffer.py +++ /dev/null @@ -1,336 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -from logging import getLogger -from typing import Dict, List - -import torch - -from .. import parallel_state - -logger = getLogger(__name__) - - -def shard_buffer(buffer: torch.Tensor): - """ - Shard buffer into dp_size chunks of equal size. - """ - data_parallel_world_size = parallel_state.get_data_parallel_world_size( - with_context_parallel=True - ) - assert buffer.numel() % data_parallel_world_size == 0 - shard_size = buffer.numel() // data_parallel_world_size - sharded_buffer = [ - buffer[(r * shard_size) : ((r + 1) * shard_size)] for r in range(data_parallel_world_size) - ] - return sharded_buffer - - -class Bucket: - """ - Bucket to keep track of a subset of the model's gradients. Provides functionality to register - when params in the bucket have grads ready to be synced; an asynchronous communication call - is automatically launched when _all_ params in the bucket have grads ready. - - Arguments: - params: List of parameters whose gradients are collated in this bucket. - data: View in larger GradBuffer that this bucket is responsible for. - offset: Offset of this bucket's view in the larger GradBuffer. - data_parallel_group: Data-parallel process group. - overlap_grad_reduce: If true, overlap communication with backprop computation by - breaking up grads into buckets. If false, single synchronous communication call - is used instead. - use_distributed_optimizer: If true, issue reduce-scatter communication calls as part - of distributed optimizer. If false, issue all-reduce communication calls. - """ - - def __init__( - self, - params: List[torch.nn.Parameter], - data: torch.Tensor, - offset: int, - data_parallel_group: torch.distributed.ProcessGroup, - overlap_grad_reduce: bool, - use_distributed_optimizer: bool, - ): - # State for bookkeeping: params is the set of parameters this bucket is - # responsible for, params_with_grad is the set of parameters with grads - # available. When overlap_grad_reduce is True, communication (all-reduce - # or reduce-scatter) is issued when params_with_grad equals params. - self.params_list = params - self.params = set(params) - self.params_with_grad = set() - self.data = data - # The distributed optimizer needs to keep track of this bucket's offset - # within the full grad_buffer. 
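Per-token loss normalization above defers the division by the global token count until after the grad reduction: `num_tokens` is broadcast from the last pipeline stage, summed across DP replicas, and then each model chunk scales its gradients by its reciprocal. A single-process sketch of the final scaling step (the collectives are indicated by comments; `fake_grads` is illustrative):

```python
import torch

# After the DP reduction, gradients hold *sums* over tokens when
# calculate_per_token_loss=True (gradient_scaling_factor was 1.0).
fake_grads = [torch.full((2, 2), 6.0), torch.full((3,), 12.0)]

num_tokens = torch.tensor(3.0)
# torch.distributed.broadcast(num_tokens, src=last_pp_rank, group=pp_group)  # share from last stage
# torch.distributed.all_reduce(num_tokens, group=dp_group)                   # sum over DP replicas

if num_tokens > 0:
    scaling = 1.0 / num_tokens
    for g in fake_grads:  # DistributedDataParallel.scale_gradients does this per buffer
        g.mul_(scaling)

print(fake_grads[0])  # tensor([[2., 2.], [2., 2.]])
```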
- self.offset = offset - self.data_parallel_group = data_parallel_group - self.overlap_grad_reduce = overlap_grad_reduce - self.use_distributed_optimizer = use_distributed_optimizer - - self.data_parallel_world_size = torch.distributed.get_world_size(group=data_parallel_group) - self.data_parallel_rank = torch.distributed.get_rank(group=data_parallel_group) - - self.reset() - - def reset(self): - """ - Reset metadata in bucket in preparation for the next iteration of training. - """ - self.params_with_grad = set() - self.communication_handle = None - self.communication_issued = False - - def start_grad_sync(self): - """ - Initiates grad sync (all-reduce or reduce-scatter) communication operation - for this bucket. - - When overlap_grad_reduce is set to True, dispatches an asynchronous - communication call. When overlap_grad_reduce is set to False, makes - synchronous call. - """ - assert ( - self.communication_handle is None and not self.communication_issued - ), 'Should not have multiple communication calls in flight at once' - - self.data /= self.data_parallel_world_size - # Use async_op only when overlap_grad_reduce is True. - if self.use_distributed_optimizer: - local_data_view = shard_buffer(self.data)[self.data_parallel_rank] - self.communication_handle = torch.distributed._reduce_scatter_base( - local_data_view, - self.data, - group=self.data_parallel_group, - async_op=self.overlap_grad_reduce, - ) - else: - self.communication_handle = torch.distributed.all_reduce( - self.data, group=self.data_parallel_group, async_op=self.overlap_grad_reduce - ) - self.communication_issued = True - - def finish_grad_sync(self): - """ - Finishes grad sync (all-reduce or reduce-scatter) communication operation - for this bucket. - - When overlap_grad_reduce is set to True, waits for asynchronous communication - call to complete. When overlap_grad_reduce is set to False, makes synchronous call. - """ - # If overlap_grad_reduce is False, start (and finish) synchronous communication call here. - if not self.overlap_grad_reduce: - self.start_grad_sync() - return - assert self.communication_handle is not None and self.communication_issued, ( - f'Communication call has not been issued for this bucket ' - f'({len(self.params_with_grad)}/{len(self.params)} params have grad available)' - ) - self.communication_handle.wait() - - def register_grad_ready(self, param: torch.nn.Parameter): - """ - Registers grads for the passed-in param to be "ready" for grad sync. - - When the number of microbatches is greater than 1, we only want to register - grads as ready when processing the last microbatch and overlap_grad_reduce is True. - """ - assert param in self.params, 'Param is not in the bucket' - assert param not in self.params_with_grad, 'Cannot set grad twice' - assert ( - self.overlap_grad_reduce - ), 'register_grad_ready() should be called only when overlapping grad reduce' - self.params_with_grad.add(param) - # If all params in bucket have grads available, issue communication call. - if len(self.params_with_grad) == len(self.params): - self.start_grad_sync() - - -class GradBuffer: - """ - Groups gradients into a contiguous buffer, and then breaks the buffer into buckets with - roughly `bucket_size` parameters each. - - Arguments: - numel: True number of elements. - numel_padded: Number of elements in underlying tensor. - dtype: Type of underlying tensor. - params: List of parameters whose gradients are collated in the underlying tensor. - data_parallel_group: Data-parallel process group. 
- bucket_size: The rough size of each bucket in terms of number of parameters. - param_to_name: Mapping from `torch.nn.Parameter` to name (for logging purposes). - overlap_grad_reduce: If true, overlap communication with backprop computation by - breaking up grads into buckets. If false, single synchronous communication call - is used instead. - use_distributed_optimizer: If true, issue reduce-scatter communication calls as part - of distributed optimizer. If false, issue all-reduce communication calls. - """ - - def __init__( - self, - numel: int, - numel_padded: int, - dtype: torch.dtype, - params: List[torch.nn.Parameter], - data_parallel_group: torch.distributed.ProcessGroup, - bucket_size: int, - param_to_name: Dict[torch.nn.Parameter, str], - overlap_grad_reduce: bool, - use_distributed_optimizer: bool, - ): - self.numel = numel - self.numel_padded = numel_padded - self.dtype = dtype - self.data = torch.zeros( - self.numel_padded, - dtype=self.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - - self.buckets = [] - self.param_to_bucket = {} - self.param_to_bucket_index = {} - self.overlap_grad_reduce = overlap_grad_reduce - self.use_distributed_optimizer = use_distributed_optimizer - - self.is_last_microbatch = True - - # Check that params are unique. - unique_params = set() - for param in params: - assert param not in unique_params - unique_params.add(param) - del unique_params - - # Helper function to create new bucket, add it to list of buckets, and - # also update param->bucket mapping. - def _set_bucket( - bucket_params: List[torch.nn.Parameter], data_start_index: int, data_end_index: int - ): - - # Get appropriate view into global GradBuffer. - bucket_data = self._get( - torch.Size([data_end_index - data_start_index]), data_start_index - ) - bucket = Bucket( - bucket_params, - bucket_data, - data_start_index, - data_parallel_group, - self.overlap_grad_reduce, - self.use_distributed_optimizer, - ) - self.buckets.append(bucket) - for bucket_param in bucket_params: - assert bucket_param not in self.param_to_bucket - assert bucket_param not in self.param_to_bucket_index - self.param_to_bucket[bucket_param] = bucket - self.param_to_bucket_index[bucket_param] = len(self.buckets) - 1 - - # Map the grads to the buffer and bucket them. - data_start_index = 0 - bucket_data_start_index = data_start_index - bucket_params = set() - - # Iterate through parameters in reverse order to roughly follow backprop order. - for param in params[::-1]: - # Skip parameters that don't require gradients. - if not param.requires_grad: - continue - this_numel = param.data.nelement() - data_end_index = data_start_index + this_numel - param.main_grad = self._get(param.data.shape, data_start_index) - bucket_params.add(param) - - # If we have enough elements already, form a new buffer. - # If bucket_size is None, accumulate everything into a single bucket. - if bucket_size is not None: - if (data_end_index - bucket_data_start_index) >= bucket_size: - _set_bucket(bucket_params, bucket_data_start_index, data_end_index) - bucket_data_start_index = data_end_index - bucket_params = set() - data_start_index = data_end_index - - # Add remaining params to a new bucket. - if len(bucket_params) > 0: - _set_bucket(bucket_params, bucket_data_start_index, data_end_index) - - if not overlap_grad_reduce: - assert len(bucket_params) == len( - params - ), 'All params should be in one bucket when overlap_grad_reduce is False' - - # Print buckets for all PP stages. 
- if ( - parallel_state.get_data_parallel_rank(with_context_parallel=True) == 0 - and parallel_state.get_tensor_model_parallel_rank() == 0 - ): - logger.info( - f'Number of buckets for gradient all-reduce / reduce-scatter: {len(self.buckets)}' - ) - for index, bucket in enumerate(self.buckets): - numel = 0 - for param in bucket.params: - numel += param.data.nelement() - logger.info(f'Params for bucket {index+1} ({numel} elements):') - for param in bucket.params: - logger.info(f' {param_to_name[param]}') - - def _get(self, shape: torch.Size, start_index: int) -> torch.Tensor: - """ - Return a tensor with the input `shape` as a view into the 1-D data starting at - `start_index`. - """ - end_index = start_index + shape.numel() - assert end_index <= self.numel, 'Requested tensor is out of buffer range' - buffer_tensor = self.data[start_index:end_index] - buffer_tensor = buffer_tensor.view(shape) - return buffer_tensor - - def reset(self): - """ - Zero out the underlying buffer and reset all buckets in preparation for the next - iteration of training. - """ - self.data.zero_() - for bucket in self.buckets: - bucket.reset() - self.is_last_microbatch = True - - def start_grad_sync(self): - """ - Initiates grad sync (all-reduce or reduce-scatter) communication operations - for all buckets in the grad buffer. - - When overlap_grad_reduce is set to True, dispatches asynchronous communication - calls. When overlap_grad_reduce is set to False, calls synchronous - communication ops. - """ - for bucket in self.buckets: - bucket.start_grad_sync() - - def finish_grad_sync(self): - """ - Finishes grad sync (all-reduce or reduce-scatter) communication operations - for all buckets in the grad buffer. - - When overlap_grad_reduce is set to True, waits for asynchronous communication - calls to complete. When overlap_grad_reduce is set to False, calls synchronous - communication ops. - """ - for bucket in self.buckets: - bucket.finish_grad_sync() - - def register_grad_ready(self, param: torch.nn.Parameter): - """ - Registers grads for the passed-in param to be "ready" for grad sync. - - When the number of microbatches is greater than 1, we only want to register - grads as ready when processing the last microbatch and overlap_grad_reduce is True. - """ - assert ( - self.overlap_grad_reduce - ), 'register_grad_ready() should only be called when overlap_grad_reduce is True' - if self.is_last_microbatch: - bucket = self.param_to_bucket[param] - bucket.register_grad_ready(param) diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py new file mode 100644 index 0000000000..3d04c18790 --- /dev/null +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -0,0 +1,882 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
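The deleted `grad_buffer.py` above is superseded by `param_and_grad_buffer.py` below, which maps both `param.data` and `param.main_grad` into contiguous buffers and pads them so they split evenly across data-parallel ranks. A minimal sketch of those two ideas with toy sizes (illustrative only, not the `_ParamAndGradBuffer` implementation itself):

```python
# Illustrative sketch with toy sizes; assumes the padded buffer length is
# divisible by the data-parallel world size, which the real code enforces
# by padding bucket boundaries.
import torch

dp_size = 4
flat_grad = torch.zeros(16)                      # padded, contiguous grad storage

weight_grad = flat_grad[0:6].view(2, 3)          # "main_grad" views share storage
bias_grad = flat_grad[6:9].view(3)               # with the flat buffer

weight_grad.fill_(1.0)                           # writing a view fills the flat buffer
bias_grad.fill_(0.5)
assert flat_grad.sum() == 7.5

shard_size = flat_grad.numel() // dp_size        # equal shard per DP rank, used as the
shards = [                                       # reduce-scatter output / all-gather input
    flat_grad[r * shard_size : (r + 1) * shard_size] for r in range(dp_size)
]
assert all(s.numel() == shard_size for s in shards)
```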
+ +import logging +import math +import warnings +from contextlib import nullcontext +from enum import Enum +from functools import partial +from typing import Dict, List, Optional + +import torch +from torch.distributed import _coalescing_manager + +from megatron.core.rerun_state_machine import get_rerun_state_machine + +from ..utils import is_float8tensor, is_torch_min_version, log_on_each_pipeline_stage +from .distributed_data_parallel_config import DistributedDataParallelConfig + +logger = logging.getLogger(__name__) + + +if is_torch_min_version("1.13.0"): + dist_all_gather_func = torch.distributed.all_gather_into_tensor + dist_reduce_scatter_func = torch.distributed.reduce_scatter_tensor +else: + dist_all_gather_func = torch.distributed._all_gather_base + dist_reduce_scatter_func = torch.distributed._reduce_scatter_base + + +class BufferType(Enum): + """ + Enumeration for buffer type. + """ + + PARAM = 1 + GRAD = 2 + + +def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int): + """ + Shard buffer into data_parallel_world_size chunks of equal size. + """ + assert buffer.numel() % data_parallel_world_size == 0 + shard_size = buffer.numel() // data_parallel_world_size + sharded_buffer = [ + buffer[(r * shard_size) : ((r + 1) * shard_size)] for r in range(data_parallel_world_size) + ] + return sharded_buffer + + +class _ParamAndGradBucket: + """ + Bucket to keep track of a subset of the model's parameters and gradients. + + Args: + params: List of parameters whose gradients are collated in this bucket. + param_data: View in _ParamAndGradBuffer.param_data that this bucket is responsible for. + grad_data: View in _ParamAndGradBuffer.grad_data that this bucket is responsible for. + offset: Offset of this bucket's view in the larger _ParamAndGradBuffer. + numel_unpadded: Number of unpadded elements in bucket. + gradient_scaling_factor: This factor is utilized to scale gradients prior to their + communication. Its application is twofold: it facilitates the averaging of gradients + and the scaling of gradients in the context of the Mixture of Experts (MoE) model. + bucket_id: Index of bucket in buffer. + """ + + def __init__( + self, + params: List[torch.nn.Parameter], + param_data: Optional[torch.Tensor], + grad_data: torch.Tensor, + offset: int, + numel_unpadded: int, + gradient_scaling_factor: float, + bucket_id: int, + ): + self.params_list = params + self.params = set(params) + # Make sure there are no duplicate params. + assert len(self.params_list) == len(self.params) + self.param_data = param_data + self.grad_data = grad_data + # The distributed optimizer needs to keep track of this bucket's offset + # within the full grad_buffer. + self.offset = offset + self.numel_unpadded = numel_unpadded + self.gradient_scaling_factor = gradient_scaling_factor + self.bucket_id = bucket_id + + +class _ParamAndGradBucketGroup: + """ + Put multiple buckets into a group so that their communications can be aggregated together. + Provides functionality to register when params in the bucket group have grads ready to be + synced; an asynchronous communication call is automatically launched when _all_ params in + the bucket group have grads ready. + + Args: + buckets: A list of buckets. + ddp_config: DistributedDataParallel config object. + collective_group: intra_distributed_optimizer_instance_group if using distributed + optimizer, data_parallel_group if not. + collective_group_size: World size using the intra data-parallel group. 
+ """ + + def __init__( + self, + buckets: List[_ParamAndGradBucket], + ddp_config: DistributedDataParallelConfig, + collective_group: torch.distributed.ProcessGroup, + collective_group_size: int, + ): + self.buckets = buckets + self.ddp_config = ddp_config + + if self.ddp_config.use_distributed_optimizer: + self.intra_distributed_optimizer_instance_group = collective_group + self.intra_distributed_optimizer_instance_size = collective_group_size + self.intra_distributed_optimizer_instance_rank = torch.distributed.get_rank( + group=collective_group + ) + else: + self.data_parallel_group = collective_group + + # State for bookkeeping: params is the set of parameters this bucket group is + # responsible for, params_with_grad is the set of parameters with grads + # available. When overlap_grad_reduce is True, communication (all-reduce + # or reduce-scatter) is issued when params_with_grad equals params. + self.param_to_bucket = {} + self.params = set() + for bucket in self.buckets: + for param in bucket.params_list: + self.param_to_bucket[param] = bucket + self.params.add(param) + + self.next_param_gather_bucket_group = None + + if self.ddp_config.num_distributed_optimizer_instances > 1: + self.inter_distributed_optimizer_instance_group = None + self.communication_stream = None + + self.reset() + self.param_gather_handle = None + self.param_gather_dispatched = False + self.grad_reduce_handle = None + + def reset(self): + """ + Reset metadata in bucket group in preparation for the next iteration of training. + """ + self.params_with_grad = set() + self.is_last_microbatch = True + + def check_grads(self, check_for_nan_or_inf, check_for_large): + """ + Make sure norm of grads in bucket are not NaN prior to data-parallel + all-reduce / reduce-scatter. + """ + rerun_state_machine = get_rerun_state_machine() + for i in range(len(self.buckets)): + grad_norm = self.buckets[i].grad_data.norm(p=2) + # check for NaN, Inf and unexpectedly large grads + if check_for_nan_or_inf: + rerun_state_machine.validate_result( + result=grad_norm, + rejection_func=torch.isnan, + message=f"found NaN in local grad norm for bucket #{i} " + f"in backward pass before data-parallel communication collective", + tolerance=0.001, # 0.1% tolerance to account for non-deterministic FA backward + fatal=True, + ) + rerun_state_machine.validate_result( + result=grad_norm, + rejection_func=torch.isinf, + message=f"found Inf in local grad norm for bucket #{i} " + f"in backward pass before data-parallel communication collective", + tolerance=0.001, # 0.1% tolerance to account for non-deterministic FA backward + fatal=True, + ) + if check_for_large: + rerun_state_machine.validate_result( + result=grad_norm, + rejection_func=partial( + rerun_state_machine.is_unexpectedly_large, threshold=10, context="grads" + ), + message=f"found unexpected large grads in bucket #{i} " + f"in backward pass before data-parallel communication collective", + tolerance=0.001, # 0.1% tolerance to account for non-deterministic FA backward + fatal=False, + ) + + def start_param_sync(self, force_sync: bool = False): + """ + Initiates all necessary param all-gathers for this bucket. + + When ddp_config.overlap_param_gather is set to True, dispatches an asynchronous + communication call (unless force_sync is True). When ddp_config.overlap_param_gather + is set to False, makes synchronous call. + + Args: + force_sync (bool, optional): force synchronous collective regardless of + other settings if true. 
+ """ + assert self.ddp_config.use_distributed_optimizer + + if force_sync: + if self.param_gather_handle is not None: + self.param_gather_handle.wait() + self.param_gather_handle = None + return + else: + assert self.param_gather_handle is None + + async_op = self.ddp_config.overlap_param_gather and not force_sync + # Coalesce communication kernels across buckets in the bucket group. + with _coalescing_manager( + self.intra_distributed_optimizer_instance_group, async_ops=async_op + ) as cm: + for bucket in self.buckets: + local_data_view = shard_buffer( + bucket.param_data, self.intra_distributed_optimizer_instance_size + )[self.intra_distributed_optimizer_instance_rank] + dist_all_gather_func( + bucket.param_data, + local_data_view, + group=self.intra_distributed_optimizer_instance_group, + async_op=async_op, + ) + if async_op: + self.param_gather_handle = cm + else: + # When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used, + # `cm` is not None, which is different from when `_coalescing_manager` is not used in + # which case the torch.distributed._all_gather_base() will return None. In order to + # maintain consistency with prior code, we need to manually set communication handle to + # None. + self.param_gather_handle = None + self.param_gather_dispatched = True + + def finish_param_sync(self, skip_next_bucket_dispatch: bool = False): + """ + Finishes param sync communication operation for this bucket. Dispatches + next bucket's param sync if available, unless skip_next_bucket_dispatch + is True. + + When ddp_config.overlap_param_gather is set to True, waits for asynchronous + communication call to complete (and dispatches one if one is not already + outstanding). Throws assertion error if ddp_config.overlap_param_gather is set to + False. + + Args: + skip_next_bucket_dispatch (bool, optional): if true, dispatch next + bucket's communication if available. + """ + assert self.ddp_config.use_distributed_optimizer + assert self.ddp_config.overlap_param_gather + + # If current bucket's param AG has not been dispatched, dispatch it now (e.g., first + # AG bucket in first model chunk if ddp_config.align_param_gather is False). + if not self.param_gather_dispatched: + self.start_param_sync() + + if self.param_gather_handle is not None: + self.param_gather_handle.wait() + self.param_gather_handle = None + # Dispatch next bucket's asynchronous param AG only if it has not been dispatched yet. + if self.next_param_gather_bucket_group is not None and not skip_next_bucket_dispatch: + if self.next_param_gather_bucket_group.param_gather_dispatched: + warnings.warn( + "The next bucket's parameter all-gather operation has already been " + "dispatched. This may be caused by a mismatch between the order of " + "parameter registration and forward pass execution, which will " + "hurt the communication-computation overlap performance." + ) + else: + self.next_param_gather_bucket_group.start_param_sync() + + def start_grad_sync(self): + """ + Initiates grad sync (all-reduce or reduce-scatter) communication operations + for all buckets in the bucket group. + + When ddp_config.overlap_grad_reduce is set to True, dispatches an asynchronous + communication call. When ddp_config.overlap_grad_reduce is set to False, makes + synchronous call. 
+ """ + assert ( + self.grad_reduce_handle is None + ), 'Should not have multiple communication calls outstanding at once' + + if self.ddp_config.check_for_nan_in_grad or self.ddp_config.check_for_large_grads: + self.check_grads( + check_for_nan_or_inf=self.ddp_config.check_for_nan_in_grad, + check_for_large=self.ddp_config.check_for_large_grads, + ) + + # gradient_scaling_factor already takes into account whether we are computing + # an average or sum in the data-parallel collective. + for bucket in self.buckets: + if bucket.gradient_scaling_factor != 1.0: + bucket.grad_data *= bucket.gradient_scaling_factor + + # Decide reduce_op. + reduce_op = torch.distributed.ReduceOp.SUM + if self.ddp_config.average_in_collective: + reduce_op = torch.distributed.ReduceOp.AVG + + # We use the following stream synchronization for the gradient reduction + # within and across DistOpt instances. + + # Compute Stream: -------------Gradient compute------------------- + # Comm. Stream: ------(wait for NCCL)-----(wait for NCCL)------- + # NCCL Stream: -------RS------ -------AR------ + + # Use async communications only when overlap_grad_reduce is True. + async_op = ( + self.ddp_config.overlap_grad_reduce + and self.ddp_config.num_distributed_optimizer_instances == 1 + ) + if ( + self.ddp_config.num_distributed_optimizer_instances > 1 + and self.ddp_config.overlap_grad_reduce + ): + # Assign a communication stream if we have multiple DistOpt instances and we + # need to overlap communication. + stream_context = torch.cuda.stream(self.communication_stream) + + # The RS/AR communication stream needs to wait for the default stream + # to complete its gradient computation before launching the next + # gradient reduction collective. + self.communication_stream.wait_stream(torch.cuda.default_stream()) + else: + stream_context = nullcontext() + + if self.ddp_config.use_distributed_optimizer: + communication_group = self.intra_distributed_optimizer_instance_group + else: + communication_group = self.data_parallel_group + + # Coalesce communication kernels across buckets in the bucket group. + with stream_context, _coalescing_manager(communication_group, async_ops=async_op) as cm: + for bucket in self.buckets: + if self.ddp_config.use_distributed_optimizer: + local_data_view = shard_buffer( + bucket.grad_data, self.intra_distributed_optimizer_instance_size + )[self.intra_distributed_optimizer_instance_rank] + dist_reduce_scatter_func( + local_data_view, + bucket.grad_data, + op=reduce_op, + group=communication_group, + async_op=async_op, + ) + else: + torch.distributed.all_reduce( + bucket.grad_data, op=reduce_op, group=communication_group, async_op=async_op + ) + + # With multiple DistOpt instances, we need to all-reduce across instances. + if ( + self.ddp_config.use_distributed_optimizer + and self.ddp_config.num_distributed_optimizer_instances > 1 + ): + + # Create a new coalescing manager for the inter-instance all-reduce. 
+ with stream_context, _coalescing_manager( + self.inter_distributed_optimizer_instance_group, async_ops=async_op + ) as cm: + for bucket in self.buckets: + local_data_view = shard_buffer( + bucket.grad_data, self.intra_distributed_optimizer_instance_size + )[self.intra_distributed_optimizer_instance_rank] + + torch.distributed.all_reduce( + local_data_view, + op=reduce_op, + group=self.inter_distributed_optimizer_instance_group, + async_op=async_op, + ) + + if async_op: + self.grad_reduce_handle = cm + else: + # When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used, + # `cm` is not None, which is different from when `_coalescing_manager` is not used in + # which case the torch.distributed._reduce_scatter_base() will return None. In order to + # maintain consistency with prior code, we need to manually set communication handle to + # None. + self.grad_reduce_handle = None + + def finish_grad_sync(self): + """ + Finishes grad sync (all-reduce or reduce-scatter) communication operations + for all buckets in the bucket group. + + When ddp_config.overlap_grad_reduce is set to True, waits for asynchronous + communication call to complete. When ddp_config.overlap_grad_reduce is set to False, + makes synchronous call. + """ + self.param_gather_dispatched = False + # If overlap_grad_reduce is False, start (and finish) synchronous communication call here. + if not self.ddp_config.overlap_grad_reduce: + self.start_grad_sync() + return + # When using multiple DistOpt instances, we don't need to sync here as we launch + # communications on a separate communication stream. + if self.ddp_config.num_distributed_optimizer_instances > 1: + torch.cuda.default_stream().wait_stream(self.communication_stream) + return + assert self.grad_reduce_handle is not None, ( + f'Communication call has not been issued for this bucket ' + f'({len(self.params_with_grad)}/{len(self.params)} params have grad available)' + ) + self.grad_reduce_handle.wait() + self.grad_reduce_handle = None + + def register_grad_ready(self, param: torch.nn.Parameter): + """ + Registers grads for the passed-in param to be "ready" for grad sync. + + When the number of microbatches is greater than 1, we only want to register + grads as ready when processing the last microbatch and ddp_config.overlap_grad_reduce + is True. + """ + assert ( + self.ddp_config.overlap_grad_reduce + ), 'register_grad_ready() should only be called when overlap_grad_reduce is True' + if self.is_last_microbatch: + assert param in self.param_to_bucket, 'Param is not in the bucket group' + assert param not in self.params_with_grad, 'Cannot set grad twice' + self.params_with_grad.add(param) + # If all params in bucket group have grads available, issue communication call. + if len(self.params_with_grad) == len(self.params): + self.start_grad_sync() + + +class _ParamAndGradBuffer: + """ + Groups parameters and gradients into a contiguous buffer, and then breaks the buffer into + buckets with roughly `bucket_size` parameters each. + + Args: + ddp_config: DistributedDataParallel config object. + param_dtype: Type of param tensor. + grad_dtype: Type of grad tensor. + params: List of parameters whose parameters and gradients are collated in the underlying + tensor. + data_parallel_group: Data-parallel process group. + bucket_size: The rough size of each bucket in terms of number of parameters. + param_to_name: Mapping from `torch.nn.Parameter` to name (for logging purposes). 
+ gradient_scaling_factor: This factor is utilized to scale gradients prior to their + communication. Its application is twofold: it facilitates the averaging of gradients + and the scaling of gradients in the context of the Mixture of Experts (MoE) model. + param_indices: The index of each param among the params with same dtype, if a param is fp8, + use its "fake" high precision dtype to determine which params have same dtype with it. + These indices are needed when loading a non-native-fp8 checkpoint in native-fp8 mode. + """ + + def __init__( + self, + ddp_config: DistributedDataParallelConfig, + param_dtype: torch.dtype, + grad_dtype: torch.dtype, + params: List[torch.nn.Parameter], + data_parallel_group: torch.distributed.ProcessGroup, + bucket_size: int, + param_to_name: Dict[torch.nn.Parameter, str], + gradient_scaling_factor: float, + param_indices: List[int], + ): + self.ddp_config = ddp_config + self.params = params + self.param_indices = param_indices + + # Check that params are unique. + unique_params = set() + for param in params: + assert param not in unique_params + unique_params.add(param) + del unique_params + + # Store attributes that will be needed later. + self.param_dtype = param_dtype + self.grad_dtype = grad_dtype + self.data_parallel_group = data_parallel_group + self.data_parallel_world_size = torch.distributed.get_world_size( + group=self.data_parallel_group + ) + self.gradient_scaling_factor = gradient_scaling_factor + + # Data structures to store underlying buckets and relevant indexing data. + self.buckets = [] + self.param_to_bucket = {} # Param -> bucket mapping. + self.param_index_map = {} # Param -> location in buffer mapping (used in dist. optimizer). + + def _pad(number_to_be_padded: int, divisor: int) -> int: + return int(math.ceil(number_to_be_padded / divisor) * divisor) + + def _pad_end_of_bucket_if_needed(bucket_end_index: int) -> int: + """ + Pads end index of bucket if using distributed optimizer (to ensure uniform sharding). + """ + if self.ddp_config.use_distributed_optimizer: + # Workaround for TE bug causing cuBLAS to pick an incompatible algorithm. + # This also helps cuBLAS pick more efficient algorithms for GEMMs. + # We now ensure that all buckets start at a memory address that is 256-byte + # aligned (128 values since params and grads use >= 16-bit precision). + if self.ddp_config.pad_buckets_for_high_nccl_busbw: + # Make sure the bucket size is divisible by a large power of 2 (2^16) to + # ensure NCCL collectives have high bus bandwidth at large DP counts, + # since NCCL message size (which for ring algorithms is bucket_size / + # dp_size) apparently needs to be divisible by a power of 2 for high busbw. + bucket_size_divisor = math.lcm(self.data_parallel_world_size, 128, 2**16) + else: + bucket_size_divisor = math.lcm(self.data_parallel_world_size, 128) + return _pad(bucket_end_index, bucket_size_divisor) + return bucket_end_index + + def _pad_start_of_param_if_needed(param_start_index: int) -> int: + """ + Pads start index of param if using distributed optimizer (to ensure "good" alignment). + """ + if self.ddp_config.use_distributed_optimizer: + # Ensure that params start at 128-byte aligned addresses (64 values + # since params are >= 16-bit precision). + return _pad(param_start_index, 64) + return param_start_index + + # First, figure out how many elements should be in the underlying buffer storage. 
+ # Note that if we need to split the buffer into smaller buckets, each of these + # might need to be padded as well (if using the distributed optimizer). + param_start_index = 0 + bucket_start_index = param_start_index + bucket_params = set() + self.bucket_indices = [] + per_bucket_numel_unpadded = [] + bucket_id = 0 + + def _update_bucket_metadata(param_end_index: int) -> int: + """ + Record metadata for the bucket starting at bucket_start_index and ending with the + passed-in param_end_index. Returns the bucket's end_index. + """ + nonlocal bucket_start_index, bucket_params, bucket_id + per_bucket_numel_unpadded.append(param_end_index - bucket_start_index) + bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index) + + # Record metadata of new bucket. + self.bucket_indices.append((bucket_start_index, bucket_end_index)) + bucket_start_index = bucket_end_index + + # Prepare for next bucket. + bucket_params = set() + bucket_id += 1 + + # Return the potentially padded bucket_end_index. + return bucket_end_index + + def _does_param_require_new_bucket(param): + """ + Split shared embedding parameters into separate bucket if using distributed + optimizer that makes use of reduce-scatters instead of all-reduces. + This ensures that the first and last pipeline stage partition optimizer state + for the shared embedding parameters the same way across DP replicas, allowing + the DP reduce-scatter to be before the embedding all-reduce. + """ + return ( + getattr(param, "shared_embedding", False) + and self.ddp_config.use_distributed_optimizer + ) + + for param in params[::-1]: + # Iterate through parameters in reverse order to roughly follow backprop order. + + this_numel = param.data.nelement() + param_start_index = _pad_start_of_param_if_needed(param_start_index) + + # Create bucket with collected parameters if current param needs its own bucket. + if _does_param_require_new_bucket(param): + # We are creating a bucket for the already accumulated parameters, whose params + # end at the current param_start_index. + if self.ddp_config.use_distributed_optimizer: + # Make sure new bucket is appropriately padded. + if param_start_index % self.data_parallel_world_size != 0: + param_start_index = _pad_end_of_bucket_if_needed(param_start_index) + if len(bucket_params) > 0: + bucket_end_index = _update_bucket_metadata(param_start_index) + + param_end_index = param_start_index + this_numel + self.param_index_map[param] = (param_start_index, param_end_index, bucket_id) + bucket_params.add(param) + + # If we have enough elements already or the current param is part of the shared + # embedding layer and needs a separate bucket, form a new bucket. + if ( + bucket_size is not None and (param_end_index - bucket_start_index) >= bucket_size + ) or _does_param_require_new_bucket(param): + bucket_end_index = _update_bucket_metadata(param_end_index) + param_start_index = bucket_end_index + else: + param_start_index = param_end_index + + # Add remaining params to a new bucket. + if len(bucket_params) > 0: + bucket_end_index = _update_bucket_metadata(param_end_index) + + # Next, create underlying storage for buffer (with numel elements that includes + # padding as necessary). 
+ self.numel = bucket_end_index + self.numel_unpadded = sum(per_bucket_numel_unpadded) + assert self.numel_unpadded <= self.numel + if self.ddp_config.use_distributed_optimizer: + assert self.numel % self.data_parallel_world_size == 0 + else: + assert self.numel == self.numel_unpadded + + self.param_data = None + # Only re-map param tensors if using distributed optimizer. + if self.ddp_config.use_distributed_optimizer: + self.param_data = torch.zeros( + self.numel, + dtype=self.param_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + self.grad_data = torch.zeros( + self.numel, + dtype=self.grad_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + + # Finally, map param.data and param.main_grad fields to buffers. + bucket_params = [] + bucket_start_index = 0 + cur_bucket_id = 0 + for param in params[::-1]: + param_start_index, param_end_index, bucket_id = self.param_index_map[param] + + # Assign param.data to appropriate segment of self.param_data. + if self.param_data is not None: + old_param_data = param.data + new_param_data = self._get( + param.data.shape, param_start_index, buffer_type=BufferType.PARAM + ) + if is_float8tensor(param): + param._data = new_param_data + else: + param.data = new_param_data + assert old_param_data._base is None + # Copy tensor values (from initialization or checkpoint). + param.data.detach().copy_(old_param_data) + del old_param_data + + param.main_grad = self._get( + param.data.shape, param_start_index, buffer_type=BufferType.GRAD + ) + if bucket_id != cur_bucket_id: + bucket_end_index = _pad_end_of_bucket_if_needed(param_start_index) + self.buckets.append( + self._new_bucket( + bucket_params=bucket_params, + start_index=bucket_start_index, + end_index=bucket_end_index, + numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id], + bucket_id=cur_bucket_id, + ) + ) + bucket_start_index = bucket_end_index + bucket_params = [] + assert cur_bucket_id + 1 == len(self.buckets) + assert bucket_id == cur_bucket_id + 1 + cur_bucket_id = bucket_id + bucket_params.append(param) + + # Add remaining params to a new bucket. + if len(bucket_params) > 0: + bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index) + self.buckets.append( + self._new_bucket( + bucket_params=bucket_params, + start_index=bucket_start_index, + end_index=bucket_end_index, + numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id], + bucket_id=cur_bucket_id, + ) + ) + + # Log buckets for all PP stages. + log_strs = [] + log_strs.append( + f'Number of buckets for gradient all-reduce / reduce-scatter: {len(self.buckets)}' + ) + for index, bucket in enumerate(self.buckets): + numel = 0 + for param in bucket.params: + numel += param.data.nelement() + log_strs.append( + f"Params for bucket {index+1} ({numel} elements, " + f"{bucket.grad_data.nelement()} padded size):" + ) + for param in bucket.params: + log_strs.append(f'\t{param_to_name[param]}') + log_on_each_pipeline_stage(logger, logging.INFO, '\n'.join(log_strs)) + + def scale_gradients(self, scaling_factor: float) -> None: + """Scale the gradient data by `scaling_factor`.""" + self.grad_data *= scaling_factor + + def _get(self, shape: torch.Size, start_index: int, buffer_type: BufferType) -> torch.Tensor: + """ + Return a tensor with the input `shape` as a view into the 1-D data starting at + `start_index`. 
+ """ + end_index = start_index + shape.numel() + assert end_index <= self.numel, 'Requested tensor is out of buffer range' + if buffer_type == BufferType.PARAM: + assert self.param_data is not None + buffer_tensor = self.param_data[start_index:end_index] + elif buffer_type == BufferType.GRAD: + buffer_tensor = self.grad_data[start_index:end_index] + else: + raise Exception("Illegal buffer type provided to GradBuffer._get() function") + buffer_tensor = buffer_tensor.view(shape) + return buffer_tensor + + def _new_bucket( + self, + bucket_params: List[torch.nn.Parameter], + start_index: int, + end_index: int, + numel_unpadded: int, + bucket_id: int, + ) -> _ParamAndGradBucket: + """ + Helper function that creates a new bucket. Also updates param->bucket mapping. + """ + + # Assert that indices are correctly padded (if needed), and that bucket + # position is same as originally computed. + if self.ddp_config.use_distributed_optimizer: + assert start_index % self.data_parallel_world_size == 0 + assert end_index % self.data_parallel_world_size == 0 + assert (start_index, end_index) == self.bucket_indices[bucket_id] + + # Get appropriate view into global _ParamAndGradBuffer. + bucketed_param_data = None + if self.param_data is not None: + bucketed_param_data = self._get( + torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.PARAM + ) + bucketed_grad_data = self._get( + torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.GRAD + ) + bucket = _ParamAndGradBucket( + params=bucket_params, + param_data=bucketed_param_data, + grad_data=bucketed_grad_data, + offset=start_index, + numel_unpadded=numel_unpadded, + gradient_scaling_factor=self.gradient_scaling_factor, + bucket_id=bucket_id, + ) + for bucket_param in bucket_params: + assert bucket_param not in self.param_to_bucket + self.param_to_bucket[bucket_param] = bucket + + return bucket + + def reset(self): + """ + Zero out the underlying grad_buffer. + """ + self.grad_data.zero_() + + +def partition_buckets( + buffers: List[_ParamAndGradBuffer], force_single_bucket_group: bool = False +) -> List[_ParamAndGradBucketGroup]: + """ + Automatically regroup the buckets of input buffers and return a list of bucket groups. + + In some scenarios, we need to put buckets from different buffers into a group so that their + communication can be aggregated. + + For example, when there are both fp8 weights and bf16 biases in the model and virtual + pipeline parallelism is enabled, each model chunk will have an fp8 bucket and a bf16 bucket, + which doubles the number of communication kernels, and because of the use of + CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back communications will prevent the + overlap of communication kernels with computation kernels. + + The grouping strategy is: + 1. If force_single_bucket_group is True, put all buckets across all buffers into a single + bucket group. + 2. If force_single_bucket_group is False, when there is no fp8 buffer in the input buffers, + let each bucket group have only one bucket. + 3. If force_single_bucket_group is False, when using fp8 params, merge all non-fp8 buckets + into the last fp8 bucket group. + - Since the non-fp8 parameters (typically the biases of various layers) are relatively + small, they are likely to be grouped into a single non-fp8 bucket. + - The fp8 buckets start from the end of the model, i.e., the first bucket corresponds to + the end of the model, while the last bucket corresponds to the beginning. 
+ - If we combine the non-fp8 bucket with the first fp8 bucket, we cannot initiate the + reduce-scatter to synchronize gradients after the backward pass at the end of the model + has completed. This is because we need to wait for the non-fp8 params from the beginning + layers to obtain their gradients. + - Combining the non-fp8 bucket with the last fp8 bucket can help avoid this issue. + + Args: + buffers (list): list of input buffers. + single_bucket_group_per_buffer (bool, optional): force group all buckets in each buffer + into a single bucket group. + """ + + if len(buffers) == 0: + return [] + + dtype_to_buffer_map = {} + for buffer in buffers: + dtype = buffer.param_dtype + # Make sure that the param_dtype of any two buffers is different. + assert dtype not in dtype_to_buffer_map + dtype_to_buffer_map[dtype] = buffer + + # Case 1: Put all buckets into a single bucket group if force_single_bucket_group is True. + if force_single_bucket_group: + buckets = [] + ddp_config = buffers[0].ddp_config + data_parallel_group = buffers[0].data_parallel_group + data_parallel_world_size = buffers[0].data_parallel_world_size + for buffer in buffers: + assert ddp_config == buffer.ddp_config + assert data_parallel_group == buffer.data_parallel_group + assert data_parallel_world_size == buffer.data_parallel_world_size + buckets.extend(buffer.buckets) + + bucket_group = _ParamAndGradBucketGroup( + buckets, ddp_config, data_parallel_group, data_parallel_world_size + ) + return [bucket_group] + + if torch.uint8 not in dtype_to_buffer_map: + # Case 2: When there is no fp8 buffer in the input buffers, let each bucket group have + # only one bucket. + bucket_groups = [] + for buffer in buffers: + for bucket in buffer.buckets: + bucket_groups.append( + _ParamAndGradBucketGroup( + [bucket], + buffer.ddp_config, + buffer.data_parallel_group, + buffer.data_parallel_world_size, + ) + ) + return bucket_groups + else: + # Case 3: When using fp8 params, merge all non-fp8 buckets into the last fp8 bucket group. + non_fp8_buckets = [] + for buffer in buffers: + if buffer.param_dtype != torch.uint8: + for bucket in buffer.buckets: + non_fp8_buckets.append(bucket) + + bucket_groups = [] + fp8_buffer = dtype_to_buffer_map[torch.uint8] + for bucket in fp8_buffer.buckets: + if len(bucket_groups) == len(fp8_buffer.buckets) - 1: + # The last bucket group. + group_buckets = [bucket] + non_fp8_buckets + else: + # The first N-1 bucket groups. + group_buckets = [bucket] + bucket_groups.append( + _ParamAndGradBucketGroup( + group_buckets, + buffer.ddp_config, + buffer.data_parallel_group, + buffer.data_parallel_world_size, + ) + ) + return bucket_groups diff --git a/megatron/core/distributed/torch_fully_sharded_data_parallel.py b/megatron/core/distributed/torch_fully_sharded_data_parallel.py new file mode 100644 index 0000000000..73af5d46b2 --- /dev/null +++ b/megatron/core/distributed/torch_fully_sharded_data_parallel.py @@ -0,0 +1,123 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from typing import List + +import torch + +try: + from torch.distributed import DeviceMesh + from torch.distributed._composable.fsdp import fully_shard + + HAVE_FSDP = True +except ImportError: + HAVE_FSDP = False + +from megatron.core.utils import is_float8tensor + +from .. 
import parallel_state, tensor_parallel +from ..models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from ..models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from ..transformer.transformer_config import TransformerConfig +from ..transformer.transformer_layer import TransformerLayer +from .data_parallel_base import _BaseDataParallel + + +class TorchFullyShardedDataParallel(_BaseDataParallel): + """ + Enables fully sharded data parallelism by wrapping the given model with + the PyTorch FSDP2 API: + https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md + To utilize this class, PyTorch version >= 2.4.0 is required. + + Args: + config: Transformer config object. + module: Underlying model. + sub_modules_to_wrap: List of sub_modules to shard with FSDP. + Parameters within each sub_module will be all-gathered just-in-time. + The default list includes the following submodules derived from the + GPT model architecture: + TransformerLayer (all Transformer layers) + LanguageModelEmbedding (initial embedding layer) + RotaryEmbedding (initial RoPE layer) + tensor_parallel.ColumnParallelLinear (final output layer) + """ + + def __init__( + self, + config: TransformerConfig, + module: torch.nn.Module, + sub_modules_to_wrap: List[torch.nn.Module] = [ + TransformerLayer, + LanguageModelEmbedding, + RotaryEmbedding, + tensor_parallel.ColumnParallelLinear, + ], + **kwargs + ): + + assert ( + HAVE_FSDP + ), 'TorchFullyShardedDataParallel requires PyTorch >= 2.4.0 with FSDP 2 support.' + + super().__init__(config=config, module=module) + self.data_parallel_group = parallel_state.get_data_parallel_group( + with_context_parallel=True + ) + + mesh = DeviceMesh.from_group(self.data_parallel_group, "cuda") + + kwargs = {"mesh": mesh} + + def save_custom_attrs(module): + custom_attrs = {} + for name, param in module.named_parameters(): + attrs = vars(param) + if is_float8tensor(param): + # disable fp8 transpose cache and perform transposing fp8 weights + # at each micro-batch because torch-FSDP doesn't recognize the + # micro-batch id, thus removing unnecessary memory stores + attrs['_fp8_attrs']['transpose_invalid'] = False + del attrs['_fp8_attrs']['transpose'] + custom_attrs[name] = {k: v for k, v in attrs.items()} + return custom_attrs + + def restore_custom_attrs(module, custom_attrs): + for name, param in module.named_parameters(): + if name in custom_attrs: + for attr_name, attr_value in custom_attrs[name].items(): + setattr(param, attr_name, attr_value) + + # Save the custom attributes on Parameters before FSDP overwrites them. + # See https://github.com/pytorch/pytorch/issues/136929. + attrs = save_custom_attrs(self.module) + + prev_module = None + for sub_module in self.module.modules(): + # Wrap individual submodules to fetch parameters just-in-time rather than + # conservatively fetching all parameters at the start of each iteration. + # See https://github.com/pytorch/pytorch/issues/114299. + if any( + isinstance(sub_module, sub_module_to_wrap) + for sub_module_to_wrap in sub_modules_to_wrap + ): + fully_shard(sub_module, **kwargs) + + # Explicitly set the FSDP backward prefetch schedule to prevent activation + # recomputation from disrupting the automatically generated default schedule. + if config.recompute_granularity is not None: + sub_module.set_modules_to_backward_prefetch( + [prev_module] if prev_module else [] + ) + prev_module = sub_module + + # Wrap the root module as required by the FSDP API. 
+ # See https://github.com/pytorch/pytorch/issues/114299. + fully_shard(self.module, **kwargs) + + restore_custom_attrs(self.module, attrs) + + def load_state_dict(self, state_dict, strict=True): + """ + No-op because tensors are already loaded in-place by + `_load_base_checkpoint` with FSDP2.""" + pass diff --git a/megatron/core/export/__init__.py b/megatron/core/export/__init__.py new file mode 100644 index 0000000000..f8011007a5 --- /dev/null +++ b/megatron/core/export/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/core/export/data_type.py b/megatron/core/export/data_type.py new file mode 100644 index 0000000000..38fbdea8f6 --- /dev/null +++ b/megatron/core/export/data_type.py @@ -0,0 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from enum import Enum + +DataType = Enum('DataType', ["bfloat16", "float16", "float32"]) diff --git a/megatron/core/export/export_config.py b/megatron/core/export/export_config.py new file mode 100644 index 0000000000..2cc1e208be --- /dev/null +++ b/megatron/core/export/export_config.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from dataclasses import dataclass + + +@dataclass +class ExportConfig: + """Base configuration for Megatron Core Export + + These parameters control the export setting for trtllm + """ + + inference_tp_size: int = 1 + + inference_pp_size: int = 1 + + use_parallel_embedding: bool = False + + use_embedding_sharing: bool = False diff --git a/megatron/core/export/model_type.py b/megatron/core/export/model_type.py new file mode 100644 index 0000000000..6a33d6440e --- /dev/null +++ b/megatron/core/export/model_type.py @@ -0,0 +1,7 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from enum import Enum + +ModelType = Enum( + 'ModelType', ["gpt", "gptnext", "llama", "falcon", "starcoder", "mixtral", "gemma"] +) diff --git a/megatron/core/export/trtllm/__init__.py b/megatron/core/export/trtllm/__init__.py new file mode 100644 index 0000000000..f8011007a5 --- /dev/null +++ b/megatron/core/export/trtllm/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/core/export/trtllm/engine_builder/__init__.py b/megatron/core/export/trtllm/engine_builder/__init__.py new file mode 100644 index 0000000000..f8011007a5 --- /dev/null +++ b/megatron/core/export/trtllm/engine_builder/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py b/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py new file mode 100644 index 0000000000..df8ea627b7 --- /dev/null +++ b/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py @@ -0,0 +1,154 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
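The export primitives added above (`DataType`, `ExportConfig`, `ModelType`) are a plain dataclass and two enums; a small usage sketch with arbitrary example values:

```python
# Arbitrary example values; only exercises the new enums and dataclass above.
from megatron.core.export.data_type import DataType
from megatron.core.export.export_config import ExportConfig
from megatron.core.export.model_type import ModelType

export_config = ExportConfig(inference_tp_size=2, inference_pp_size=1)
model_type = ModelType.llama
dtype = DataType.bfloat16

# ExportConfig is a plain dataclass, so its fields can be read directly,
# e.g. to derive the world size expected by the TRTLLM config.
world_size = export_config.inference_tp_size * export_config.inference_pp_size
print(model_type.name, dtype.name, world_size)   # llama bfloat16 2
```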
+ +import tensorrt_llm +from tensorrt_llm._common import check_max_num_tokens +from tensorrt_llm.builder import BuildConfig +from tensorrt_llm.commands.build import build as build_trtllm +from tensorrt_llm.logger import logger +from tensorrt_llm.lora_manager import LoraConfig +from tensorrt_llm.models.modeling_utils import optimize_model, preprocess_weights +from tensorrt_llm.plugin import PluginConfig + + +class TRTLLMEngineBuilder: + """A utility class to build TRTLLM engine""" + + @staticmethod + def build_and_save_engine( + engine_dir: str, + trtllm_model_weights: dict, + trtllm_model_config, + max_input_len: int = 1024, + max_output_len: int = 1024, + max_batch_size: int = 4, + lora_ckpt_list=None, + use_lora_plugin=None, + max_lora_rank: int = 64, + lora_target_modules=None, + max_prompt_embedding_table_size: int = 0, + paged_kv_cache: bool = True, + remove_input_padding: bool = True, + paged_context_fmha: bool = False, + use_refit: bool = False, + max_num_tokens: int = None, + max_seq_len: int = None, + opt_num_tokens: int = None, + max_beam_width: int = 1, + tokens_per_block: int = 128, + multiple_profiles: bool = False, + gpt_attention_plugin: str = "auto", + gemm_plugin: str = "auto", + reduce_fusion: bool = False, + ): + """Method to build the TRTLLM Engine + + This method uses the TRTLLMEngineBuilder to build and save the engine to engine dir + + Args: + engine_dir (str): The file path to save the engine + trtllm_model_weights (dict): The TRTLLM converted model weights dict + trtllm_model_config : The TRTLLM Config + max_input_len (int, optional): Max input length. Defaults to 1024. + max_output_len (int, optional): Max output length. Defaults to 1024. + max_batch_size (int, optional): Max batch size. Defaults to 4. + model_type (ModelType, optional): ModelType enum. Defaults to ModelType.gpt. + lora_ckpt_list (_type_, optional): Lora checkpoint list. Defaults to None. + use_lora_plugin (_type_, optional): Use lora plugin. Defaults to None. + max_lora_rank (int, optional): Max lora rank. Defaults to 64. + lora_target_modules (_type_, optional): Lora target modules. Defaults to None. + max_prompt_embedding_table_size (int, optional): Defaults to 0. + paged_kv_cache (bool, optional): Use Paged KV cache. Defaults to True. + remove_input_padding (bool, optional): Remove input padding. Defaults to True. + paged_context_fmha (bool, optional): Paged context fmha. Defaults to False. + use_refit (bool, optional): Use refit. Defaults to False. + max_num_tokens (int, optional): Max num of tokens. Defaults to None. + max_seq_len (int, optional): Max seq length. Defaults to None. + opt_num_tokens (int, optional): Opt number of tokens. Defaults to None. + max_beam_width (int, optional): Max beam width. Defaults to 1. + tokens_per_block (int, optional): Nmber of tokens per block. Defaults to 128. + multiple_profiles (bool, optional): Use multiple profiles. Defaults to False. + gpt_attention_plugin (str, optional): Gpt attention plugin to use. Defaults to "auto". + gemm_plugin (str, optional): Gemma plugin to use. Defaults to "auto". 
+ """ + architecture = ( + "LLaMAForCausalLM" + if trtllm_model_config.architecture == "LlamaForCausalLM" + else trtllm_model_config.architecture + ) + try: + model_cls = getattr(tensorrt_llm.models, architecture) + except: + raise AttributeError(f"Could not find TRTLLM model for architecture: {architecture}!") + + logger.set_level("info") + plugin_config = PluginConfig() + plugin_config.gpt_attention_plugin = gpt_attention_plugin + plugin_config.gemm_plugin = gemm_plugin + if paged_kv_cache: + plugin_config.enable_paged_kv_cache(tokens_per_block=tokens_per_block) + else: + plugin_config.paged_kv_cache = False + plugin_config.remove_input_padding = remove_input_padding + plugin_config.use_paged_context_fmha = paged_context_fmha + plugin_config.multiple_profiles = multiple_profiles + plugin_config.reduce_fusion = reduce_fusion + + if max_seq_len is None: + max_seq_len = max_input_len + max_output_len + + max_num_tokens, opt_num_tokens = check_max_num_tokens( + max_num_tokens=max_num_tokens, + opt_num_tokens=opt_num_tokens, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + max_input_len=max_input_len, + max_beam_width=max_beam_width, + remove_input_padding=remove_input_padding, + enable_context_fmha=plugin_config.context_fmha, + tokens_per_block=tokens_per_block, + multiple_profiles=multiple_profiles, + ) + + build_dict = { + 'max_input_len': max_input_len, + 'max_output_len': max_output_len, + 'max_batch_size': max_batch_size, + 'max_beam_width': max_beam_width, + 'max_seq_len': max_seq_len, + 'max_num_tokens': max_num_tokens, + 'opt_num_tokens': opt_num_tokens, + 'max_prompt_embedding_table_size': max_prompt_embedding_table_size, + 'gather_context_logits': False, + 'gather_generation_logits': False, + 'strongly_typed': False, + 'builder_opt': None, + 'use_refit': use_refit, + 'multiple_profiles': multiple_profiles, + } + build_config = BuildConfig.from_dict(build_dict, plugin_config=plugin_config) + + if use_lora_plugin is not None: + # build_config.plugin_config.set_lora_plugin(use_lora_plugin) + # build_config.plugin_config._lora_plugin = use_lora_plugin + lora_config = LoraConfig( + lora_dir=lora_ckpt_list, + lora_ckpt_source='nemo', # TODO : NEED TO SEE HOW TO HANDLE THIS FOR MCORE + max_lora_rank=max_lora_rank, + lora_target_modules=lora_target_modules, + ) + build_config.lora_config = lora_config + + model = model_cls.from_config(trtllm_model_config) + + model = optimize_model( + model, + use_parallel_embedding=trtllm_model_config.use_parallel_embedding, + share_embedding_table=trtllm_model_config.share_embedding_table, + ) + + preprocess_weights(trtllm_model_weights, trtllm_model_config) + model.load(trtllm_model_weights) + engine = build_trtllm(model, build_config) + + engine.save(engine_dir) + return engine diff --git a/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py b/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py new file mode 100644 index 0000000000..f8011007a5 --- /dev/null +++ b/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py b/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py new file mode 100644 index 0000000000..d3cd7ff296 --- /dev/null +++ b/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+from megatron.core.export.trtllm.trtllm_layers import TRTLLMLayers + +# Map the most common mcore layers to TRTLLM layers +# pylint: disable=line-too-long +DEFAULT_CONVERSION_DICT = { + # INPUT + 'embedding.word_embeddings.weight': TRTLLMLayers.vocab_embedding, + 'embedding.position_embeddings.weight': TRTLLMLayers.position_embedding, + # ATTENTION + 'decoder.layers.input_layernorm.weight': TRTLLMLayers.input_layernorm_weight, + 'decoder.layers.input_layernorm.bias': TRTLLMLayers.input_layernorm_bias, + 'decoder.layers.self_attention.linear_qkv.weight': TRTLLMLayers.attention_qkv_weight, + 'decoder.layers.self_attention.linear_qkv.bias': TRTLLMLayers.attention_qkv_bias, + 'decoder.layers.self_attention.linear_proj.weight': TRTLLMLayers.attention_dense_weight, + 'decoder.layers.self_attention.linear_proj.bias': TRTLLMLayers.attention_dense_bias, + # MLP + 'decoder.layers.pre_mlp_layernorm.weight': TRTLLMLayers.post_layernorm_weight, + 'decoder.layers.pre_mlp_layernorm.bias': TRTLLMLayers.post_layernorm_bias, + 'decoder.layers.mlp.linear_fc1.weight': TRTLLMLayers.mlp_fc_weight, + 'decoder.layers.mlp.linear_fc1.bias': TRTLLMLayers.mlp_fc_bias, + 'decoder.layers.mlp.linear_fc2.weight': TRTLLMLayers.mlp_projection_weight, + 'decoder.layers.mlp.linear_fc2.bias': TRTLLMLayers.mlp_projection_bias, + # EXPERTS + 'decoder.layers.mlp.experts.experts.linear_fc1.weight': TRTLLMLayers.mlp_fc_weight_mixture_of_experts, + 'decoder.layers.mlp.experts.experts.linear_fc2.weight': TRTLLMLayers.mlp_projection_weight_mixture_of_experts, + 'decoder.layers.mlp.router.weight': TRTLLMLayers.mlp_router_weight, + # FINAL LAYER NORM + 'decoder.final_layernorm.weight': TRTLLMLayers.final_layernorm_weight, + 'decoder.final_layernorm.bias': TRTLLMLayers.final_layernorm_bias, + # OUTPUT LAYER + 'output_layer.weight': TRTLLMLayers.lm_head, + # TRANSFORMER ENGINE LAYER NORM + # ATTENTION + 'decoder.layers.self_attention.linear_qkv.layer_norm_weight': TRTLLMLayers.input_layernorm_weight, + 'decoder.layers.self_attention.linear_qkv.layer_norm_bias': TRTLLMLayers.input_layernorm_bias, + # MLP + 'decoder.layers.mlp.linear_fc1.layer_norm_weight': TRTLLMLayers.post_layernorm_weight, + 'decoder.layers.mlp.linear_fc1.layer_norm_bias': TRTLLMLayers.post_layernorm_bias, +} diff --git a/megatron/core/export/trtllm/trt_model_config.py b/megatron/core/export/trtllm/trt_model_config.py new file mode 100644 index 0000000000..2ed09398c2 --- /dev/null +++ b/megatron/core/export/trtllm/trt_model_config.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +import tensorrt_llm + +from megatron.core.export.model_type import ModelType + +TRT_MODEL_CONFIG = { + ModelType.gpt: tensorrt_llm.models.gpt.config.GPTConfig, + ModelType.gptnext: tensorrt_llm.models.gpt.config.GPTConfig, + ModelType.starcoder: tensorrt_llm.models.gpt.config.GPTConfig, + ModelType.mixtral: tensorrt_llm.models.llama.config.LLaMAConfig, + ModelType.llama: tensorrt_llm.models.llama.config.LLaMAConfig, + ModelType.gemma: tensorrt_llm.models.GemmaConfig, + ModelType.falcon: tensorrt_llm.models.falcon.config.FalconConfig, +} diff --git a/megatron/core/export/trtllm/trt_model_type.py b/megatron/core/export/trtllm/trt_model_type.py new file mode 100644 index 0000000000..f45ff1786e --- /dev/null +++ b/megatron/core/export/trtllm/trt_model_type.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
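The keys in `DEFAULT_CONVERSION_DICT` above omit layer indices. The sketch below illustrates that keying convention only; the real lookup lives in `trtllm_layers.py` and the weight converters, which may differ in detail:

```python
# Illustration of the layer-number-free keying convention; stripping the numeric
# layer index from an mcore parameter name recovers a key in the dict above.
import re

from megatron.core.export.trtllm.model_to_trllm_mapping.default_conversion_dict import (
    DEFAULT_CONVERSION_DICT,
)

param_name = "decoder.layers.11.self_attention.linear_qkv.weight"
key = re.sub(r"\.\d+\.", ".", param_name)      # drop the layer index
print(DEFAULT_CONVERSION_DICT[key])            # TRTLLMLayers.attention_qkv_weight
```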
+ +from megatron.core.export.model_type import ModelType + +TRT_MODEL_TYPE_STRING = { + ModelType.gpt: 'GPTForCausalLM', + ModelType.gptnext: 'GPTForCausalLM', + ModelType.starcoder: 'GPTForCausalLM', + ModelType.mixtral: 'LlamaForCausalLM', + ModelType.llama: 'LlamaForCausalLM', + ModelType.gemma: 'GemmaForCausalLM', + ModelType.falcon: 'FalconForCausalLM', +} diff --git a/megatron/core/export/trtllm/trtllm_helper.py b/megatron/core/export/trtllm/trtllm_helper.py new file mode 100644 index 0000000000..45093b673d --- /dev/null +++ b/megatron/core/export/trtllm/trtllm_helper.py @@ -0,0 +1,588 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from typing import Union + +import tensorrt_llm +import torch +from tensorrt_llm.functional import non_gated_version +from tensorrt_llm.layers import MoeConfig + +from megatron.core.export.data_type import DataType +from megatron.core.export.export_config import ExportConfig +from megatron.core.export.model_type import ModelType +from megatron.core.export.trtllm.engine_builder.trtllm_engine_builder import TRTLLMEngineBuilder +from megatron.core.export.trtllm.model_to_trllm_mapping.default_conversion_dict import ( + DEFAULT_CONVERSION_DICT, +) +from megatron.core.export.trtllm.trt_model_config import TRT_MODEL_CONFIG +from megatron.core.export.trtllm.trt_model_type import TRT_MODEL_TYPE_STRING +from megatron.core.export.trtllm.trtllm_layers import TRTLLMLayers + +# pylint: disable=line-too-long +from megatron.core.export.trtllm.trtllm_weights_converter.distributed_trtllm_model_weights_converter import ( + DistributedTRTLLMModelWeightsConverter, +) +from megatron.core.export.trtllm.trtllm_weights_converter.single_device_trtllm_model_weights_converter import ( + SingleDeviceTRTLLMModelWeightsConverter, +) +from megatron.core.transformer.transformer_config import TransformerConfig + + +class TRTLLMHelper: + """TRTLLM Helper class to convert export and build TRTLLM model.""" + + def __init__( + self, + transformer_config: TransformerConfig, + model_type: ModelType, + trtllm_conversion_dict: dict = {}, + position_embedding_type: str = 'learned_absolute', + max_position_embeddings: int = None, + rotary_percentage: int = 1.0, + rotary_base: int = 10000, + moe_tp_mode: int = 2, + multi_query_mode: bool = False, + activation: str = "gelu", + seq_len_interpolation_factor: float = None, + moe_renorm_mode=None, + share_embeddings_and_output_weights=False, + ): + """Constructor for the TRTLLMHelper + + There are two public API's supported by this helper. + a) get_trtllm_pretrained_config_and_model_weights + b) build_and_save_engine + + Args: + transformer_config (TransformerConfig): The transformer config + model_type (ModelType): The type of the input model. Enum (megatron.core.export.model_type.ModelType) + trtllm_conversion_dict (dict, optional): A conversion dictionary that will map your model layer names to trtllm equivalent layer names. Default dictionary is given megatron/core/export/model_to_trtllm_mapping. This dict is merged into the default dict. NOTE: Ignore layer numbers in the model layer names. (e.g) decoder.layers.0.attention_qkv.weight will be decoder.layers.attention_qkv.weight in the mapping dictionary. Defaults to {}. + position_embedding_type (str, optional): The position embedding type. Defaults to None. + max_position_embeddings (int, optional): Max posistion embeddings value. Defaults to None. + rotary_percentage (int, optional): The rotary percentage if using rope embedding. Defaults to 1.0. 
+ rotary_base (int, optional): The rotary base (theta value) if using rope embeddings. Defaults to 10000. + moe_tp_mode (int, optional): TRTLLM Config. Defaults to 2. + multi_query_mode (bool, optional): Defaults to False. + activation (str, optional): Defaults to "gelu". + seq_len_interpolation_factor (float, optional): The sequence length interpolation factor if using rope embeddings. Defaults to None. + moe_renorm_mode (optional) : Renormalization mode if using mixture of experts. Defaults to None. + share_embeddings_and_output_weights (bool, optional): True if input and output layers share weights. Defaults to False. + """ + + self.transformer_config = transformer_config + self.model_type = model_type + self.trtllm_conversion_dict = DEFAULT_CONVERSION_DICT.copy() + self.trtllm_conversion_dict.update(trtllm_conversion_dict) + assert position_embedding_type in [ + 'learned_absolute', + 'rope', + ], f"Position embedding type should be one of learned_absolute, rope. You entered {position_embedding_type}" + self.position_embedding_type = position_embedding_type + self.max_position_embeddings = max_position_embeddings + self.rotary_percentage = rotary_percentage + self.rotary_base = rotary_base + self.moe_tp_mode = moe_tp_mode + self.multi_query_mode = multi_query_mode + self.activation = activation + self.seq_len_interpolation_factor = seq_len_interpolation_factor + self.moe_renorm_mode = moe_renorm_mode + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.weights_converter = None + + def _get_trtllm_config( + self, + export_config: ExportConfig, + world_size: int, + gpus_per_node: int, + vocab_size_padded: int, + dtype: DataType, + fp8_quantized: bool = False, + fp8_kvcache: bool = False, + ): + """Get TRTLLM Config + + Returns appropriate TRTLLM PretrainedConfig used by TRTLLM for building engine + + Args: + export_config (ExportConfig): The export config that defines inference tp , pp size etc. 
+ world_size (int): The number of gpus (Mostly TP * PP) + gpus_per_node (int): Num gpus per node + vocab_size_padded (int): Padded vocab size + dtype (DataType): The datatype or model precision + + Returns: + GPTConfig or the LLamaConfig or the PretrainedConfig constructed from your model config + """ + hidden_act = self.activation + hidden_act = ( + hidden_act.split("-")[-1] + if self.transformer_config.num_moe_experts + else non_gated_version(hidden_act) + ) + + config = { + 'architecture': TRT_MODEL_TYPE_STRING[self.model_type], + 'dtype': dtype.name, + 'num_hidden_layers': self.transformer_config.num_layers, + 'num_attention_heads': self.transformer_config.num_attention_heads, + 'num_key_value_heads': ( + self.transformer_config.num_query_groups + if self.transformer_config.num_query_groups + else self.transformer_config.num_attention_heads + ), + 'head_size': self.transformer_config.kv_channels, + 'hidden_size': self.transformer_config.hidden_size, + 'intermediate_size': self.transformer_config.ffn_hidden_size, + 'norm_epsilon': self.transformer_config.layernorm_epsilon, + 'vocab_size': vocab_size_padded, + 'position_embedding_type': ( + "rope_gpt_neox" if self.position_embedding_type == "rope" else "learned_absolute" + ), + 'max_position_embeddings': self.max_position_embeddings, + 'hidden_act': hidden_act, + 'use_parallel_embedding': export_config.use_parallel_embedding, + 'embedding_sharding_dim': 0, + 'share_embedding_table': export_config.use_embedding_sharing, + 'quantization': { + 'quant_algo': "FP8" if fp8_quantized else None, + 'kv_cache_quant_algo': "FP8" if fp8_kvcache else None, + }, + 'bias': self.transformer_config.add_bias_linear, + 'apply_query_key_layer_scaling': False, + 'rotary_pct': self.rotary_percentage, + 'rotary_base': self.rotary_base, + 'moe_num_experts': ( + 0 + if self.transformer_config.moe_router_topk == 0 + else (self.transformer_config.num_moe_experts or 1) + ), + 'moe_top_k': self.transformer_config.moe_router_topk, + 'moe_normalization_mode': self.moe_renorm_mode + or MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE, + 'moe_tp_mode': self.moe_tp_mode, + 'logits_dtype': 'float32', + 'world_size': world_size, + 'tp_size': export_config.inference_tp_size, + 'pp_size': export_config.inference_pp_size, + 'gpus_per_node': gpus_per_node, + } + + if self.model_type == ModelType.falcon: + config["new_decoder_architecture"] = ( + False if self.transformer_config.num_layers == 32 else True + ) + config["parallel_attention"] = True + + if self.seq_len_interpolation_factor is not None: + config["rotary_scaling"] = { + "type": "linear", + "factor": float(self.seq_len_interpolation_factor), + } + + config_cls = TRT_MODEL_CONFIG[self.model_type] + return config_cls(**config) + + def _load_scaling_factors(self, model_state_dict: dict) -> dict: + """Loads scaling factors from model state dictionary. + + Args: + model_state_dict (dict): Model state dictionary + Returns: + dict: Maps scaling factor key, to its value and the inverse. The inverse is used for casting the quantized weights. 
+ """ + weight_scaling_suffix = '.weights_scaling_factor' + activation_scaling_suffix = '.activation_scaling_factor' + mock_scales_dict = {} + extra_state_infix = "._extra_state" + mock_suffix = '.weight' + + for key, val in model_state_dict.items(): + if extra_state_infix in key and not key.endswith("core_attention._extra_state"): + mock_key = key.split(extra_state_infix)[0] + mock_suffix + mock_scales_dict[mock_key] = val + + mock_scales_dict = TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names( + mock_scales_dict, self.trtllm_conversion_dict, False + ) + split_gated_activation = self.activation in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"] + + scales = {} + for key, val in mock_scales_dict.items(): + if val is None: + continue + + val.seek(0) + extra_states = torch.load(val) + + activation_scaling_factor_key = key.replace(mock_suffix, activation_scaling_suffix) + weight_scaling_factor_key = key.replace(mock_suffix, weight_scaling_suffix) + + activation_scales = { + 'trt_llm_scale': extra_states['scale_inv_fwd'][0].view(1), + 'weight_multiplier': extra_states['scale_fwd'][0].view(1), + } + + weight_scales = { + 'trt_llm_scale': extra_states['scale_inv_fwd'][1].view(1), + 'weight_multiplier': extra_states['scale_fwd'][1].view(1), + } + + scales[activation_scaling_factor_key] = activation_scales + scales[weight_scaling_factor_key] = weight_scales + if split_gated_activation and ".mlp.fc" in key: + scales[activation_scaling_factor_key.replace("fc", "gate")] = activation_scales + scales[weight_scaling_factor_key.replace("fc", "gate")] = weight_scales + + return scales + + # pylint: disable=line-too-long + def get_trtllm_pretrained_config_and_model_weights( + self, + model_state_dict, + dtype: DataType, + export_config: ExportConfig = None, + on_device_distributed_conversion: bool = False, + vocab_size: int = None, + gpus_per_node: int = None, + state_dict_split_by_layer_numbers: bool = True, + fp8_quantized: bool = False, + fp8_kvcache: bool = False, + ): + """Get TRTLLM Config and Converted Model Weights + + This function returns the trtllm model weights as a list. + There are two modes for conversion. The default is to use a single device cpu/gpu for conversion. + NOTE: For faster performance, if your entire model will fit in memory, pre transfer the model state dict to cuda device and then call this function. + For on device conversion it returns weights which will be used on the device itself. + Same thing happens with the pretrained config + + Args: + model_state_dict (dict): The input model state dictionary (Entire model state loaded on CPU) or the model state dict of each GPU in the case of on_device conversion) + export_config (ExportConfig): The export config used to define inference tp size, pp size etc. Used only for on device conversion. + dtype (DataType): The data type of model precision + on_device_distributed_conversion (bool, optional): Convert on gpus in distributed setting. This assumes that the model state dict is sharded according to required inference model parallelism and that each gpu gets its part of the model state dict . Defaults to False. + vocab_size (int, optional): The vocabulary size. Defaults to None. + gpus_per_node (int, optional): The number of gpus per node. Used for on device conversion. + state_dict_split_by_layer_numbers (bool, optional): Are the model layers split by layer numbers in state dict. 
For example: mlp.fc1.weight can be represented like mlp.fc1.weight of shape [num_layers, hidden_dim, ffn_hidden_dim] or it can be like mlp.fc1.layers.0.weight of shape [hidden_dim, ffn_hidden_dim], then mlp.fc1.layers.1.weight ... for all layers. If you use representation 2, set this to True. Defaults to True
+
+        Returns:
+            Two lists: the first is the trtllm converted model weights (either on device, or a list of weights for each gpu) and the second is the trtllm_model_configs.
+        """
+        assert model_state_dict is not None, "Model state dict is not set"
+
+        scales = self._load_scaling_factors(model_state_dict) if fp8_quantized else {}
+        model_state_dict = {k: v for k, v in model_state_dict.items() if 'extra_state' not in k}
+
+        if on_device_distributed_conversion:
+            assert vocab_size is not None, "Need to pass in vocab_size for on device"
+            supported_model = self.model_type in [ModelType.gpt, ModelType.gptnext, ModelType.llama]
+            assert (
+                supported_model
+            ), "On device conversion only supported for model types gpt, gptnext and llama"
+            assert export_config is None, (
+                "Export config is inferred based on the parallel state. "
+                "If you want to set inference tp 2, then load the model with this TP2 setting and just pass in the model state dict."
+            )
+
+            assert (
+                gpus_per_node is not None
+            ), "Need to pass in gpus_per_node for on device conversion"
+            trtllm_model_weights_on_device, trtllm_model_config = (
+                self._get_trtllm_pretrained_config_and_model_weights_in_distributed_setting(
+                    model_state_dict,
+                    dtype,
+                    vocab_size,
+                    gpus_per_node,
+                    scales,
+                    fp8_quantized,
+                    fp8_kvcache,
+                )
+            )
+            return [trtllm_model_weights_on_device], [trtllm_model_config]
+
+        else:
+            assert not (
+                self.share_embeddings_and_output_weights and not export_config.use_embedding_sharing
+            ), "Found share_embeddings_and_output_weights is True in the model. So set export_config.use_embedding_sharing to True"
+            assert (
+                vocab_size is None
+            ), "Vocab size is inferred from the input layer for cpu conversion. So leave it as None"
+            trtllm_model_weights_list, trtllm_model_config_list = (
+                self._get_trtllm_pretrained_config_and_model_weights_list_on_single_device(
+                    export_config,
+                    model_state_dict,
+                    dtype,
+                    gpus_per_node,
+                    state_dict_split_by_layer_numbers,
+                    scales,
+                    fp8_quantized,
+                    fp8_kvcache,
+                )
+            )
+
+            return trtllm_model_weights_list, trtllm_model_config_list
+
+    def _add_scales_to_converter(
+        self,
+        converter: Union[
+            SingleDeviceTRTLLMModelWeightsConverter, DistributedTRTLLMModelWeightsConverter
+        ],
+        scales: dict,
+        fp8_kvcache: bool,
+    ):
+        """Adds scaling factors to the distributed and single device converters.
+
+        Args:
+            converter (ModelWeightConverter): Converter, holding the TRT-LLM model weights.
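Continuing the sketch above, the single-device (CPU/GPU) path takes the full, unsharded state dict plus an `ExportConfig` and returns one weights dict and one config per inference rank; `model` is a hypothetical MCore GPT model, and the `DataType.bfloat16` member is assumed to exist.

```python
# Hedged sketch of the single-device conversion path.
from megatron.core.export.data_type import DataType
from megatron.core.export.export_config import ExportConfig

export_config = ExportConfig(
    inference_tp_size=2,
    inference_pp_size=1,
    use_parallel_embedding=True,
    use_embedding_sharing=True,  # must be True when the model shares embeddings
)

weights_list, config_list = trtllm_helper.get_trtllm_pretrained_config_and_model_weights(
    model_state_dict=model.state_dict(),  # full (unsharded) state dict
    dtype=DataType.bfloat16,
    export_config=export_config,
)
# One entry per inference rank: len(weights_list) == tp_size * pp_size == 2
```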
+ scales (dict): Dictionary holding TRT-LLM scaling factors + fp8_kvcache (bool): If true, creates scaling factors (equal to 1.0) for kv_cache quantization + """ + trt_scales = {key: scale['trt_llm_scale'] for key, scale in scales.items()} + kv_scales = {} + if fp8_kvcache: + for key in converter.trtllm_model_weights: + if '.attention.qkv.weight' in key: + kv_key = key.split('.qkv')[0] + '.kv_cache_scaling_factor' + kv_scales[kv_key] = torch.tensor([1.0], dtype=torch.float32) + + converter.trtllm_model_weights |= trt_scales | kv_scales + + def _get_trtllm_pretrained_config_and_model_weights_in_distributed_setting( + self, + model_state_dict: dict, + dtype: DataType, + vocab_size: int, + gpus_per_node: int, + scales: dict, + fp8_quantized: bool, + fp8_kvcache: bool, + ): + """Get the TRTLLM Pretrained config and model weights list in a distributed setting + + This function assumes the model state dict is distributed according to model parallelism . + Each device gets its own model state dict + + Args: + export_config (ExportConfig): The export config to set inference tp, pp size etc. + model_state_dict (dict): The model state dictionary (All collected on cpu) + dtype (DataType): The data type or model precision + vocab_size (int): Tokenizer vocab size + gpus_per_node (int): The number of gpus per node + scales (dict): Dictionary with fp8 scaling factors + fp8_quantized (bool): True for fp8 checkpoint export + fp8_kvcache (bool): True for fp8 KV-cache quantization + Returns: + Two lists . List of trtllm converted model weights and trtllm model configs (One for each gpu). + """ + + self.weights_converter = DistributedTRTLLMModelWeightsConverter( + transformer_config=self.transformer_config, + dtype=dtype, + multi_query_mode=self.multi_query_mode, + activation=self.activation, + scales=scales, + ) + self.weights_converter.convert( + model_state_dict=model_state_dict, + trtllm_conversion_dict=self.trtllm_conversion_dict, + tokenizer_vocab_size=vocab_size, + ) + self._add_scales_to_converter(self.weights_converter, scales, fp8_kvcache) + + export_config = ExportConfig( + inference_pp_size=self.weights_converter.inference_pp_size, + inference_tp_size=self.weights_converter.inference_tp_size, + use_parallel_embedding=True, + use_embedding_sharing=self.share_embeddings_and_output_weights, + ) + + world_size = export_config.inference_tp_size * export_config.inference_pp_size + + trtllm_model_config = self._get_trtllm_config( + export_config=export_config, + world_size=world_size, + gpus_per_node=gpus_per_node, + vocab_size_padded=vocab_size, + dtype=dtype, + fp8_quantized=fp8_quantized, + fp8_kvcache=fp8_kvcache, + ) + + model_parallel_rank = ( + self.weights_converter.pp_rank * self.weights_converter.inference_tp_size + + self.weights_converter.tp_rank + ) + + trtllm_model_config.mapping = tensorrt_llm.Mapping( + world_size=world_size, + rank=model_parallel_rank, + tp_size=export_config.inference_tp_size, + pp_size=export_config.inference_pp_size, + ) + + return self.weights_converter.trtllm_model_weights, trtllm_model_config + + def _get_trtllm_pretrained_config_and_model_weights_list_on_single_device( + self, + export_config: ExportConfig, + model_state_dict: dict, + dtype: DataType, + gpus_per_node, + state_dict_split_by_layer_numbers, + scales: dict, + fp8_quantized: bool, + fp8_kvcache: bool, + ): + """Get the TRTLLM Pretrained config and model weights list (one per gpu rank) on single device (CPU/GPU) + + This function assumes the entire model state dict is present in CPU or on one GPU + + 
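By contrast, the on-device distributed path is meant to run under torchrun with the model already loaded in the desired inference TP/PP layout; each rank passes only its own shard, `export_config` stays `None`, and `tokenizer_vocab_size` below is a hypothetical value.

```python
# Hedged sketch of the on-device (distributed) conversion path, one call per rank.
weights, configs = trtllm_helper.get_trtllm_pretrained_config_and_model_weights(
    model_state_dict=model.state_dict(),  # this rank's shard only
    dtype=DataType.bfloat16,
    on_device_distributed_conversion=True,
    vocab_size=tokenizer_vocab_size,
    gpus_per_node=8,
)
# Each rank gets back single-element lists: its own weights dict and config.
```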
Args: + export_config (ExportConfig): The export config to set inference tp, pp size etc. + model_state_dict (dict): The model state dictionary (All collected on cpu) + dtype (DataType): The data type or model precision + gpus_per_node (int, optional): Number of gpus per node + state_dict_split_by_layer_numbers (bool, optional): Are the model layers split by layer numbers in state dict. For example : mlp.fc1.weight can be represented like mlp.fc1.weight of shape [num_layers, hidden_dim, ffn_hidden_dim]} or it can be like mlp.fc1.layers.0.weight of shape [hidden_dim, ffn_hidden_dim], then mlp.fc1.layers.1.weight ... for all layers. If you use represenation 2 set this to True. Defaults to True + scales (dict): Dictionary with fp8 scaling factors + fp8_quantized (bool): True for fp8 checkpoint export + fp8_kvcache (bool): True for fp8 KV-cache quantization + + Returns: + Two lists . List of trtllm converted model weights and trtllm model configs (One for each gpu). + """ + trtllm_model_configs_list = [] + trtllm_model_weights_list = [] + + self.weights_converter = SingleDeviceTRTLLMModelWeightsConverter( + export_config=export_config, + transformer_config=self.transformer_config, + dtype=dtype, + activation=self.activation, + multi_query_mode=self.multi_query_mode, + scales=scales, + ) + # Convert the input model state dict to trtllm model weights dictionary + self.weights_converter.convert( + model_state_dict=model_state_dict, + trtllm_conversion_dict=self.trtllm_conversion_dict, + state_dict_split_by_layer_numbers=state_dict_split_by_layer_numbers, + ) + + self._add_scales_to_converter(self.weights_converter, scales, fp8_kvcache) + + vocab_size_padded = self.weights_converter.get_padded_vocab_size() + world_size = export_config.inference_tp_size * export_config.inference_pp_size + gpus_per_node = gpus_per_node or export_config.inference_tp_size + + for gpu_rank in range(world_size): + mapping = tensorrt_llm.Mapping( + world_size=world_size, + rank=gpu_rank, + tp_size=export_config.inference_tp_size, + pp_size=export_config.inference_pp_size, + ) + + # Important to create a new instance everytime so that the list elements have differnt rank values in the mapping object + trtllm_model_config = self._get_trtllm_config( + export_config=export_config, + world_size=world_size, + gpus_per_node=gpus_per_node, + vocab_size_padded=vocab_size_padded, + dtype=dtype, + fp8_quantized=fp8_quantized, + fp8_kvcache=fp8_kvcache, + ) + trtllm_model_config.mapping = mapping + trtllm_model_configs_list.append(trtllm_model_config) + + # Get the model weights for each rank and append it to the trtllm_model_weights_list + trtllm_model_weights_per_gpu = self.weights_converter.get_local_model_weights_per_gpu( + mapping, trtllm_model_config + ) + trtllm_model_weights_list.append(trtllm_model_weights_per_gpu) + + return trtllm_model_weights_list, trtllm_model_configs_list + + def build_and_save_engine( + self, + engine_dir: str, + trtllm_model_weights: dict, + trtllm_model_config, + max_input_len: int = 1024, + max_output_len: int = 1024, + max_batch_size: int = 4, + lora_ckpt_list=None, + use_lora_plugin=None, + max_lora_rank: int = 64, + lora_target_modules=None, + max_prompt_embedding_table_size: int = 0, + paged_kv_cache: bool = True, + remove_input_padding: bool = True, + paged_context_fmha: bool = False, + use_refit: bool = False, + max_num_tokens: int = None, + max_seq_len: int = None, + opt_num_tokens: int = None, + max_beam_width: int = 1, + tokens_per_block: int = 128, + multiple_profiles: bool = False, + 
gpt_attention_plugin: str = "auto",
+        gemm_plugin: str = "auto",
+    ):
+        """Method to build the TRTLLM Engine
+
+        This method uses the TRTLLMEngineBuilder to build and save the engine to engine dir
+
+        Args:
+            engine_dir (str): The file path to save the engine
+            trtllm_model_weights (dict): The TRTLLM converted model weights dict
+            trtllm_model_config: The TRTLLM Config
+            max_input_len (int, optional): Max input length. Defaults to 1024.
+            max_output_len (int, optional): Max output length. Defaults to 1024.
+            max_batch_size (int, optional): Max batch size. Defaults to 4.
+            lora_ckpt_list (_type_, optional): Lora checkpoint list. Defaults to None.
+            use_lora_plugin (_type_, optional): Use lora plugin. Defaults to None.
+            max_lora_rank (int, optional): Max lora rank. Defaults to 64.
+            lora_target_modules (_type_, optional): Lora target modules. Defaults to None.
+            max_prompt_embedding_table_size (int, optional): Max size of prompt embedding table. Defaults to 0.
+            paged_kv_cache (bool, optional): Use Paged KV cache. Defaults to True.
+            remove_input_padding (bool, optional): Remove input padding. Defaults to True.
+            paged_context_fmha (bool, optional): Paged context fmha. Defaults to False.
+            use_refit (bool, optional): Use refit. Defaults to False.
+            max_num_tokens (int, optional): Max num of tokens. Defaults to None.
+            max_seq_len (int, optional): Max seq length. Defaults to None.
+            opt_num_tokens (int, optional): Opt number of tokens. Defaults to None.
+            max_beam_width (int, optional): Max beam width. Defaults to 1.
+            tokens_per_block (int, optional): Number of tokens per block. Defaults to 128.
+            multiple_profiles (bool, optional): Use multiple profiles. Defaults to False.
+            gpt_attention_plugin (str, optional): GPT attention plugin to use. Defaults to "auto".
+            gemm_plugin (str, optional): GEMM plugin to use. Defaults to "auto".
+        """
+
+        engine = TRTLLMEngineBuilder.build_and_save_engine(
+            engine_dir,
+            trtllm_model_weights,
+            trtllm_model_config,
+            max_input_len,
+            max_output_len,
+            max_batch_size,
+            lora_ckpt_list,
+            use_lora_plugin,
+            max_lora_rank,
+            lora_target_modules,
+            max_prompt_embedding_table_size,
+            paged_kv_cache,
+            remove_input_padding,
+            paged_context_fmha,
+            use_refit,
+            max_num_tokens,
+            max_seq_len,
+            opt_num_tokens,
+            max_beam_width,
+            tokens_per_block,
+            multiple_profiles,
+            gpt_attention_plugin,
+            gemm_plugin,
+        )
+
+        return engine
diff --git a/megatron/core/export/trtllm/trtllm_layers.py b/megatron/core/export/trtllm/trtllm_layers.py
new file mode 100644
index 0000000000..0cf805dcb6
--- /dev/null
+++ b/megatron/core/export/trtllm/trtllm_layers.py
@@ -0,0 +1,157 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
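Picking up the single-device sketch, one engine is typically built per inference rank from the matching weights/config pair; the output directory below is hypothetical, and how the per-rank engines are distributed afterwards is deployment-specific.

```python
# Hedged sketch: build and save one engine per rank produced earlier.
for rank, (rank_weights, rank_config) in enumerate(zip(weights_list, config_list)):
    trtllm_helper.build_and_save_engine(
        engine_dir=f"/tmp/trtllm_engines/rank{rank}",
        trtllm_model_weights=rank_weights,
        trtllm_model_config=rank_config,
        max_input_len=1024,
        max_output_len=256,
        max_batch_size=4,
    )
```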
+ +import re +from enum import Enum +from typing import Tuple + + +class TRTLLMLayers(Enum): + """TRTLLM Layer names + + This Enum will be used to map input model layer names to TRTLLM Layer names + """ + + # ONE TIME LAYERS (NOT ASSOCIATED TO TRANSFORMER BLOCK) + # Input layers + position_embedding = 'transformer.position_embedding.weight' + vocab_embedding = 'transformer.vocab_embedding.weight' + lm_head = 'lm_head.weight' + + # Output layers + final_layernorm_weight = 'transformer.ln_f.weight' + final_layernorm_bias = 'transformer.ln_f.bias' + + # TRANSFORMER LAYERS + # Attention block related layers + input_layernorm_weight = 'transformer.layers.input_layernorm.weight' + input_layernorm_bias = 'transformer.layers.input_layernorm.bias' + attention_qkv_weight = 'transformer.layers.attention.qkv.weight' + attention_qkv_bias = 'transformer.layers.attention.qkv.bias' + attention_dense_weight = 'transformer.layers.attention.dense.weight' + attention_dense_bias = 'transformer.layers.attention.dense.bias' + + # mlp layers + mlp_fc_weight = 'transformer.layers.mlp.fc.weight' + mlp_fc_bias = 'transformer.layers.mlp.fc.bias' + post_layernorm_weight = 'transformer.layers.post_layernorm.weight' + post_layernorm_bias = 'transformer.layers.post_layernorm.bias' + mlp_projection_weight = 'transformer.layers.mlp.proj.weight' + mlp_projection_bias = 'transformer.layers.mlp.proj.bias' + + # mixture of expert layers + mlp_router_weight = 'transformer.layers.mlp.router.weight' + mlp_fc_weight_mixture_of_experts = 'transformer.layers.mlp.fc.weight.expert' + mlp_projection_weight_mixture_of_experts = 'transformer.layers.mlp.proj.weight.expert' + + @staticmethod + def return_layer_name_and_number(layer_name: str) -> Tuple[str, int]: + """Helper function to return layer name and number + Given an input layer e.g decoder.layers.2.self_attention.linear_qkv.weight, + this function returns decoder.layers.self_attention.linear_qkv.weight and layernumber 2. + In case no layer number is present, it returns None for the layer number + Args: + layer_name (dict): The input layer name + + Returns: + Tuple[str, int]: The layer name , layer number (layer number could be None) + """ + # Use regular expression to find the number specifically after 'layers.' + match = re.search(r'(?<=layers\.)\d+(?=\.)', layer_name) + if match: + # Extract the number and remove it from the layer name + number = match.group(0) + layer_name_without_number = re.sub(r'\.{}\.'.format(number), '.', layer_name) + return layer_name_without_number, int(number) + else: + # Return the original name if no number is found + return layer_name, None + + # pylint: disable=line-too-long + @staticmethod + def rename_input_layer_names_to_trtllm_layer_names( + model_state_dict: dict, + trtllm_conversion_dict: dict, + state_dict_split_by_layer_numbers: bool = True, + ) -> dict: + """Helper function to rename model layer names to TRTLLM Layer names + + We go through each layer (keys) in the model state dict, + and map it to the equivalent TRTLLMLayer name (megatron/core/export/trtllm/trtllm). + If we have a layer number associated with layer, we extract it out, + map the original layer name to equivalent trtllm layer name and add layer number back. + CPU Conversion will pass in model state dict without layer numbers + (i.e decoder.layers.mlp.linear_fc1.weight of shape [num_layers, hidden_dim, 4 * hidden_dim]) . + GPU conversion will pass model state dict with each layer seperated + (i.e decoder.layers.2.mlp.linear_fc1.weight of shape [hidden_dim, 4 * hidden_dim]). 
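The regex above only strips the number that directly follows `layers.`; a quick illustration of both branches:

```python
from megatron.core.export.trtllm.trtllm_layers import TRTLLMLayers

# Values follow directly from return_layer_name_and_number as defined above.
name, num = TRTLLMLayers.return_layer_name_and_number(
    'decoder.layers.2.self_attention.linear_qkv.weight'
)
assert name == 'decoder.layers.self_attention.linear_qkv.weight' and num == 2

name, num = TRTLLMLayers.return_layer_name_and_number('embedding.word_embeddings.weight')
assert name == 'embedding.word_embeddings.weight' and num is None
```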
+ + Args: + model_state_dict (dict): The original model state dict + trtllm_conversion_dict (dict): The conversion dictionary mapping input model layer names to trtllm layer names + state_dict_split_by_layer_numbers (bool, optional): Are the model layers split by layer numbers in state dict. For example : mlp.fc1.weight can be represented like mlp.fc1.weight of shape [num_layers, hidden_dim, ffn_hidden_dim]} or it can be like mlp.fc1.layers.0.weight of shape [hidden_dim, ffn_hidden_dim], then mlp.fc1.layers.1.weight ... for all layers. If you use represenation 2 set this to True. Defaults to True + + Raises: + ValueError: In case the keys dont match to trtllm keys or if all model layers are not mapped to equivalent trtllm keys + + Returns: + dict: The model state dict with the key (i.e original model layer name) replaced by trtllm layer names + """ + for original_model_layer_name in list(model_state_dict.keys()): + if "_extra_state" in original_model_layer_name: + del model_state_dict[original_model_layer_name] + continue + + original_layer_name_without_number, layer_number = ( + TRTLLMLayers.return_layer_name_and_number(original_model_layer_name) + ) + if 'layers' in original_layer_name_without_number and state_dict_split_by_layer_numbers: + assert ( + layer_number is not None + ), f"Layer number is None for {original_model_layer_name} and state_dict_split_by_layer_numbers is set to True. Consider setting it False" + + if original_layer_name_without_number not in trtllm_conversion_dict: + raise ValueError( + f'Unable to rename key {original_layer_name_without_number}. Provide an appropriate mapping in the trtllm_conversion_dict when you initialize TRTLLMHelper' + ) + + trtllm_layer = trtllm_conversion_dict[original_layer_name_without_number] + assert isinstance( + trtllm_layer, TRTLLMLayers + ), f"{trtllm_layer} is not supported for conversion. Please use one of the TRTLLMLayerNames we provided in megatron/core/export/trtllm/trtllm_layer_names" + + value = model_state_dict.pop(original_model_layer_name) + + if layer_number is not None: + trtllm_layer_name_with_number = re.sub( + r'(?<=layers\.)', f'{layer_number}.', trtllm_layer.value + ) + model_state_dict[trtllm_layer_name_with_number] = value + else: + model_state_dict[trtllm_layer.value] = value + + return model_state_dict + + +# These layers are not associated within the transformer block. +# So they dont have a layer number (i.e independant of number of layers in the model) +NON_TRANSFORMER_LAYERS_NAMES = [ + TRTLLMLayers.vocab_embedding.value, + TRTLLMLayers.position_embedding.value, + TRTLLMLayers.lm_head.value, + TRTLLMLayers.final_layernorm_weight.value, + TRTLLMLayers.final_layernorm_bias.value, +] + + +def get_layer_name_without_prefix(layer: TRTLLMLayers) -> str: + """Get TRTLayer name without prefix + + Given a layer e.g TRTLLMLayers.attention_qkv_weight it returns 'attention.qkv.weight' + + Args: + layer (TRTLLMLayers): The TRTLLMLayer + + Returns: + str: The TRTLLMLayers suffix (i.e Removing transformer.layers. fromt he layer name) + """ + layer_name_without_prefix = layer.value.replace("transformer.layers.", "") + return layer_name_without_prefix diff --git a/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py b/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py new file mode 100644 index 0000000000..f8011007a5 --- /dev/null +++ b/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
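To see the renaming and the suffix helper in action, a small self-contained example with dummy tensors (the state-dict keys are standard mcore names from the default conversion dict):

```python
import torch

from megatron.core.export.trtllm.model_to_trllm_mapping.default_conversion_dict import (
    DEFAULT_CONVERSION_DICT,
)
from megatron.core.export.trtllm.trtllm_layers import TRTLLMLayers, get_layer_name_without_prefix

state_dict = {
    'embedding.word_embeddings.weight': torch.zeros(8, 4),
    'decoder.layers.0.self_attention.linear_proj.weight': torch.zeros(4, 4),
}
renamed = TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names(
    state_dict, DEFAULT_CONVERSION_DICT
)
assert set(renamed) == {
    'transformer.vocab_embedding.weight',
    'transformer.layers.0.attention.dense.weight',
}

# The suffix helper strips 'transformer.layers.' for endswith() matching in the converters.
assert get_layer_name_without_prefix(TRTLLMLayers.attention_qkv_weight) == 'attention.qkv.weight'
```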
diff --git a/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py b/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py new file mode 100644 index 0000000000..401988d787 --- /dev/null +++ b/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py @@ -0,0 +1,280 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from typing import Optional + +import torch +from tqdm import tqdm + +from megatron.core import parallel_state +from megatron.core.export.data_type import DataType +from megatron.core.export.trtllm.trtllm_layers import NON_TRANSFORMER_LAYERS_NAMES, TRTLLMLayers +from megatron.core.export.trtllm.trtllm_layers import get_layer_name_without_prefix as suffix +from megatron.core.tensor_parallel.utils import VocabUtility +from megatron.core.transformer.transformer_config import TransformerConfig + + +def str_dtype_to_torch(dtype: DataType): + """Get torch datatype from input datatype""" + from tensorrt_llm._utils import str_dtype_to_torch + + return str_dtype_to_torch(dtype.name) + + +# pylint: disable=line-too-long +class DistributedTRTLLMModelWeightsConverter: + """The TRTLLM Converter class used for GPU (on device) conversion + + This class is used to convert models sharded and on gpus. (It assumes that the model is already sharded appropriate to how you want to export it). (i.e) If you want to export to tp2pp2, then load the model in tp2pp2 setting and pass in their respective state dictionaries + """ + + def __init__( + self, + transformer_config: TransformerConfig, + dtype: DataType, + multi_query_mode: bool = False, + activation: str = "gelu", + scales: Optional[dict] = None, + ): + """Constructor for the TRTLLMModelWeightsConverterGPU class + + This class is responsible to convert the model weights to TRTLLM equivalent weights. + + Args: + transformer_config (TransformerConfig): The transformer config + dtype (DataType): The data type or model precision + multi_query_mode (bool, optional): Defaults to False. + activation (str, optional): Defaults to "gelu". + scales (dict, optional): Dictionary with fp8 scaling factors. + """ + if scales is None: + scales = {} + self.transformer_config = transformer_config + self.trtllm_model_weights = {} + self.storage_type = str_dtype_to_torch(dtype) + self.activation = activation + self.scales = scales + num_kv_heads = self.transformer_config.num_query_groups + if num_kv_heads == 0: + if multi_query_mode: + num_kv_heads = 1 + else: + num_kv_heads = self.transformer_config.num_attention_heads + self.num_kv_heads = num_kv_heads + + self.inference_pp_size = parallel_state.get_pipeline_model_parallel_world_size() + self.inference_tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.pp_rank = parallel_state.get_pipeline_model_parallel_rank() + self.tp_group = parallel_state.get_tensor_model_parallel_group() + vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() + + assert ( + vp_size is None or vp_size == 1 + ), "Virtual parallelism is not supported in GPU Converter. Gather the VP chunks and use PP config." 
+ + def _add_to_trtllm_model_weights(self, val: torch.Tensor, layer_name: str): + assert torch.is_tensor(val), f"Expected a tensor for {layer_name} but got {type(val)}" + scale_key = '.'.join(layer_name.split('.')[:-1]) + '.weights_scaling_factor' + storage = self.storage_type + if scale_key in self.scales and layer_name.endswith("weight"): + storage = torch.float8_e4m3fn + val = val * self.scales[scale_key]['weight_multiplier'].to(val.device) + + val = val.to(storage) + val = val.detach().contiguous() + if val.ndim >= 2: + val = torch.transpose(val.reshape(val.shape[0], -1), 0, 1) + if layer_name not in self.trtllm_model_weights: + self.trtllm_model_weights[layer_name] = torch.empty( + val.size(), dtype=val.dtype, layout=val.layout, device="cpu", pin_memory=True + ) + self.trtllm_model_weights[layer_name].copy_(val, non_blocking=True) + + def _convert_transformer_layer(self, layer_name: str, val: torch.Tensor): + """Convert Transformer layers to TRTLLM weights + + Transformer layers referes to layers within the transformber block. They have a layer number associated with them. Depending on the layer we either directly save it to trtllm_model_weights, or split it across some dimension and save the splits + + Args: + model_state_dict (dict): The input model state dictionary (All collected on CPU) + layer (TRTLLMLayerNames): The TRTLLM Layer that we want to change + """ + if val.ndim == 2: + val = val.T + + if ( + layer_name.endswith(suffix(TRTLLMLayers.input_layernorm_weight)) + or layer_name.endswith(suffix(TRTLLMLayers.input_layernorm_bias)) + or layer_name.endswith(suffix(TRTLLMLayers.post_layernorm_weight)) + or layer_name.endswith(suffix(TRTLLMLayers.post_layernorm_bias)) + or layer_name.endswith(suffix(TRTLLMLayers.attention_dense_bias)) + or layer_name.endswith(suffix(TRTLLMLayers.attention_dense_bias)) + or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_bias)) + or layer_name.endswith(suffix(TRTLLMLayers.mlp_router_weight)) + or layer_name.endswith(suffix(TRTLLMLayers.attention_dense_weight)) + or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_weight)) + ): + # Same as layernorm1p in NeMo + if ( + self.transformer_config.layernorm_zero_centered_gamma + and self.transformer_config.normalization == "LayerNorm" + and 'layernorm.weight' in layer_name + ): + val = val + 1.0 + + self._add_to_trtllm_model_weights(val=val, layer_name=layer_name) + + elif layer_name.endswith(suffix(TRTLLMLayers.mlp_fc_weight)) or layer_name.endswith( + suffix(TRTLLMLayers.mlp_fc_bias) + ): + + split_gated_activation = self.activation in [ + "swiglu", + "geglu", + "fast-swiglu", + "fast-geglu", + ] + if split_gated_activation: + vals, gates = [[n] for n in torch.chunk(val, 2, axis=-1)] + gate_layer_name = layer_name.replace("fc", "gate") + self._add_to_trtllm_model_weights(val=gates[0], layer_name=gate_layer_name) + val = vals[0] + + self._add_to_trtllm_model_weights(val=val, layer_name=layer_name) + + elif layer_name.endswith(suffix(TRTLLMLayers.attention_qkv_bias)): + qkv_hidden_dim = val.shape[0] + size_per_head = ( + qkv_hidden_dim + // (self.transformer_config.num_attention_heads + 2 * self.num_kv_heads) + * self.inference_tp_size + ) + q_num = self.transformer_config.num_attention_heads // self.num_kv_heads + + # We first concat all sub weights per tp rank together. 
+ val = val.reshape(self.num_kv_heads // self.inference_tp_size, q_num + 2, size_per_head) + qkv = torch.split(val, [q_num, 1, 1], dim=1) + split_vals = torch.concatenate( + [qkv[0].reshape(-1), qkv[1].reshape(-1), qkv[2].reshape(-1)], dim=0 + ) + self._add_to_trtllm_model_weights(val=split_vals, layer_name=layer_name) + + # TODO : Should add a atten layer dimension "qkvqkv, qqkkvv etc to see how to reshape here" + elif layer_name.endswith(suffix(TRTLLMLayers.attention_qkv_weight)): + hidden_dim = val.shape[0] + size_per_head = self.transformer_config.kv_channels + if size_per_head is None: + size_per_head = hidden_dim // self.transformer_config.num_attention_heads + q_num = self.transformer_config.num_attention_heads // self.num_kv_heads + + val = val.reshape( + hidden_dim, self.num_kv_heads // self.inference_tp_size, q_num + 2, size_per_head + ) + qkv = torch.split(val, [q_num, 1, 1], dim=2) + split_vals = torch.concatenate( + [ + qkv[0].reshape(hidden_dim, -1), + qkv[1].reshape(hidden_dim, -1), + qkv[2].reshape(hidden_dim, -1), + ], + dim=1, + ) + self._add_to_trtllm_model_weights(val=split_vals, layer_name=layer_name) + + else: + raise ValueError(f"{layer_name} cannot be handled by GPU converter") + + def _convert_non_transformer_layer(self, model_state_dict: dict, layer_name: str): + """Convert Non Transformer layers to TRTLLM weights + + Non transformer layers referes to layers that occur only once in the model (e.g Embedding , final output layer etc. ) They dont have any layer number associated with them. We remove this layer from the original state dict and cast it to storage type and convert to numpy and add it to trtllm_model_weights + + Args: + model_state_dict (dict): The input model state dictionary (All collected on CPU) + layer (TRTLLMLayerNames): The TRTLLM Layer that we want to change + """ + if layer_name in model_state_dict: + val = model_state_dict.pop(layer_name) + self._add_to_trtllm_model_weights(val=val, layer_name=layer_name) + + # ----------------Convert Embeddings---------------- + def _get_remove_vocab_padding(self, layer_name, model_state_dict, tokenizer_vocab_size): + val = model_state_dict.get(layer_name, None) + if val is None: + return None + + if self.inference_tp_size > 1: # Gather padded tensor chunks + vocab_size_padded = val.shape[0] * self.inference_tp_size + vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size( + vocab_size_padded, self.tp_rank, self.inference_tp_size + ) + dim_size = list(val.size()) + dim_size[0] = vocab_size_padded + gathered_val = torch.zeros( + dim_size, dtype=val.dtype, device=torch.cuda.current_device() + ) + gathered_val[vocab_start_index:vocab_end_index] = val + torch.distributed.all_reduce(gathered_val, group=self.tp_group) + val = gathered_val + unpadded = val[:tokenizer_vocab_size] + if self.inference_tp_size > 1: # Split gathered val for val parallel embedding + vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size( + tokenizer_vocab_size, self.tp_rank, self.inference_tp_size + ) + unpadded = unpadded[vocab_start_index:vocab_end_index] + return unpadded.T # TRTLLM expects (vocab_size, hidden_size) so need extra transpose + + @torch.no_grad() + def convert( + self, model_state_dict: dict, trtllm_conversion_dict: dict, tokenizer_vocab_size: int + ): + """Convert model weights to trtllm model weights + + This method goes through each layer in the model state dict and converts to equivalent trtllm model weights. 
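To make the QKV reshape above concrete, a worked example with hypothetical sizes: 8 attention heads, 2 KV heads, 64 channels per head, hidden size 512, inference TP 2, so each rank holds one query group.

```python
import torch

hidden, q_num, size_per_head = 512, 4, 64   # q_num = attention heads per KV head
groups_per_rank = 1                         # num_kv_heads // inference_tp_size

# This rank's fused qkv slice, already transposed to (hidden, local_qkv_dim) = (512, 384).
val = torch.randn(hidden, groups_per_rank * (q_num + 2) * size_per_head)
val = val.reshape(hidden, groups_per_rank, q_num + 2, size_per_head)
qkv = torch.split(val, [q_num, 1, 1], dim=2)
fused = torch.cat([t.reshape(hidden, -1) for t in qkv], dim=1)
# fused is (512, 384): all Q columns first, then K, then V — the non-interleaved
# layout the converter hands to TRTLLM.
assert fused.shape == (hidden, (q_num + 2) * size_per_head)
```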
It also handles splitting across TP dimension , expert split etc. + + Args: + model_state_dict (dict): The full model state dict (all on CPU) + trtllm_conversion_dict (dict): The conversion dictionary used to convert model layer names to trtllm layer names + tokenizer_vocab_size (int): The vocab size of the tokenizer + """ + + # First step is to convert input model layer names to equivalent trtllm layer names + model_state_dict = TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names( + model_state_dict=model_state_dict, trtllm_conversion_dict=trtllm_conversion_dict + ) + + # Convert the non transformer layers + for layer_name in NON_TRANSFORMER_LAYERS_NAMES: + if layer_name not in model_state_dict: + continue + if ( + layer_name in TRTLLMLayers.vocab_embedding.value + or layer_name in TRTLLMLayers.lm_head.value + ): + # For embedding layers alone we do some pre processing + embed_val = self._get_remove_vocab_padding( + layer_name, model_state_dict, tokenizer_vocab_size + ) + model_state_dict[layer_name] = embed_val + # TODO : Check if this handling of position embedding is right. + if layer_name == TRTLLMLayers.position_embedding.value: + position_embedding = model_state_dict[layer_name] + req_position_embedding = position_embedding.chunk(self.inference_tp_size)[ + self.tp_rank + ] + model_state_dict[layer_name] = req_position_embedding.T + if layer_name == TRTLLMLayers.final_layernorm_weight.value: + # Same as layernorm1p in NeMo + if ( + self.transformer_config.layernorm_zero_centered_gamma + and self.transformer_config.normalization == "LayerNorm" + ): + model_state_dict[layer_name] = model_state_dict[layer_name] + 1.0 + self._convert_non_transformer_layer( + model_state_dict=model_state_dict, layer_name=layer_name + ) + + for layer_name, value in tqdm( + model_state_dict.items(), desc="Converting to TRTLLM Weights" + ): + self._convert_transformer_layer(layer_name, value) diff --git a/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py b/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py new file mode 100644 index 0000000000..7e669fc1c6 --- /dev/null +++ b/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py @@ -0,0 +1,471 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
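The vocab gather/split in `_get_remove_vocab_padding` leans on `VocabUtility` to compute each rank's contiguous slice; a tiny standalone illustration:

```python
from megatron.core.tensor_parallel.utils import VocabUtility

# With a padded vocab of 1024 split across TP=2, rank 0 owns rows [0, 512).
start, end = VocabUtility.vocab_range_from_global_vocab_size(1024, 0, 2)
assert (start, end) == (0, 512)  # rank 1 would get (512, 1024)
```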
+ +import re +from typing import Optional + +import torch +from tqdm import tqdm + +from megatron.core.export.data_type import DataType +from megatron.core.export.export_config import ExportConfig +from megatron.core.export.trtllm.trtllm_layers import NON_TRANSFORMER_LAYERS_NAMES, TRTLLMLayers +from megatron.core.export.trtllm.trtllm_layers import get_layer_name_without_prefix as suffix +from megatron.core.transformer.transformer_config import TransformerConfig + + +# pylint: disable=line-too-long +# TODO: Writing TRT imports this way so that it can be mocked in the test_trtllm_cpu_converter.py unit test +# TODO: Figure out how to patch it directly from the trtllm library +def pad_vocab_size(vocab_size: int, tp_size: int): + """Pad vocab size based on inference size""" + from tensorrt_llm._utils import pad_vocab_size + + return pad_vocab_size(vocab_size, tp_size) + + +def str_dtype_to_torch(dtype: DataType): + """Get torch datatype from input datatype""" + from tensorrt_llm._utils import str_dtype_to_torch + + return str_dtype_to_torch(dtype.name) + + +class SingleDeviceTRTLLMModelWeightsConverter: + """Class to convert Model weights to TRTLLM weights on CPU""" + + def __init__( + self, + export_config: ExportConfig, + transformer_config: TransformerConfig, + dtype: DataType, + multi_query_mode: bool = False, + activation: str = "gelu", + scales: Optional[dict] = None, + ): + """Constructor for the TRTLLMModelWeightsConverterCPU class + + This class is responsible to convert the model weights to TRTLLM equivalent weights and also split them for each GPU rank and return as a list. + + Args: + export_config (ExportConfig): The export config with inference tp size, pp size etc. + transformer_config (TransformerConfig): The transformer config + dtype (DataType): The data type or model precision + multi_query_mode (bool, optional): Defaults to False. + activation (str, optional): Defaults to "gelu". + scales (dict, optional): Dictionary with fp8 scaling factors. + """ + if scales is None: + scales = {} + + self.export_config = export_config + self.transformer_config = transformer_config + self.trtllm_model_weights = {} + self.storage_type = str_dtype_to_torch(dtype) + self.activation = activation + self.scales = scales + num_kv_heads = self.transformer_config.num_query_groups + if num_kv_heads == 0: + if multi_query_mode: + num_kv_heads = 1 + else: + num_kv_heads = self.transformer_config.num_attention_heads + self.num_kv_heads = num_kv_heads + + def _convert_non_transformer_layer(self, model_state_dict: dict, layer_name: str): + """Convert Non Transformer layers to TRTLLM weights + + Non transformer layers referes to layers that occur only once in the model (e.g Embedding , final output layer etc. ) They dont have any layer number associated with them. We remove this layer from the original state dict and cast it to storage type and convert to numpy and add it to trtllm_model_weights + + Args: + model_state_dict (dict): The input model state dictionary (All collected on CPU) + layer_name (str): The TRTLLM Layer name that we want to convert + """ + if layer_name in model_state_dict: + val = model_state_dict.pop(layer_name) + val = val.to(self.storage_type).detach().contiguous() + self.trtllm_model_weights[layer_name] = val + + def _cast_value(self, val: torch.Tensor, layer_name: str) -> torch.Tensor: + """Casts weights to the expected datatype. + When appropriate scaling factor is found inside self.scales, the weight gets scaled before the cast. 
+ + Args: + val (torch.Tensor): Model weight + layer_name (str): Layer name, used for determining the scaling factor dictionary key + Returns: + torch.Tensor: The casted weight + """ + storage = self.storage_type + + scale_key = '.'.join(layer_name.split('.')[:-1]) + '.weights_scaling_factor' + if scale_key in self.scales and layer_name.endswith("weight"): + storage = torch.float8_e4m3fn + val = val * self.scales[scale_key]['weight_multiplier'].to(val.device) + + return val.to(storage) + + def _convert_transformer_layer(self, layer_name: str, val: torch.Tensor): + """Convert Transformer layers to TRTLLM weights + + Transformer layers referes to layers within the transformber block. They have a layer number associated with them. Depending on the layer we either directly save it to trtllm_model_weights, or split it across some dimension and save the splits + + Args: + model_state_dict (dict): The input model state dictionary (All collected on CPU) + layer (TRTLLMLayerNames): The TRTLLM Layer that we want to change + """ + + def _add_to_trtllm_model_weights(val: torch.Tensor, layer_name: str, split_type=None): + """Add the input weight to trtllm_model_weights + + Depending on split (Expert split/Tensor split/None) we split the input data and add accordingly + + Args: + val (torch.Tensor): The model weight to be added + layer_name (str): The TRTLLMlayername as a string + split_type (str, optional): The split type. Defaults to None. + """ + if split_type == 'expert_split': + for split_num, split_val in enumerate(val): + self.trtllm_model_weights[f'{layer_name}.{split_num}.bin'] = ( + self._cast_value(split_val, layer_name).detach().contiguous() + ) + elif split_type == 'tensor_split': + for split_num, split_val in enumerate(val): + if split_val.ndim >= 2: + split_val = torch.transpose(split_val.reshape(split_val.shape[0], -1), 1, 0) + + self.trtllm_model_weights[f'{layer_name}.{split_num}.bin'] = ( + self._cast_value(split_val, layer_name).detach().contiguous() + ) + else: + if val.ndim >= 2: + val = torch.transpose(val.reshape(val.shape[0], -1), 1, 0) + + self.trtllm_model_weights[layer_name] = ( + self._cast_value(val, layer_name).detach().contiguous() + ) + + if val.ndim == 2: + val = val.T + + if ( + layer_name.endswith(suffix(TRTLLMLayers.input_layernorm_weight)) + or layer_name.endswith(suffix(TRTLLMLayers.input_layernorm_bias)) + or layer_name.endswith(suffix(TRTLLMLayers.post_layernorm_weight)) + or layer_name.endswith(suffix(TRTLLMLayers.post_layernorm_bias)) + or layer_name.endswith(suffix(TRTLLMLayers.attention_dense_bias)) + or layer_name.endswith(suffix(TRTLLMLayers.attention_dense_bias)) + or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_bias)) + or layer_name.endswith(suffix(TRTLLMLayers.mlp_router_weight)) + ): + # Same as layernorm1p in NeMo + if ( + self.transformer_config.layernorm_zero_centered_gamma + and self.transformer_config.normalization == "LayerNorm" + and 'layernorm.weight' in layer_name + ): + val = val + 1.0 + + _add_to_trtllm_model_weights(val=val, layer_name=layer_name, split_type=None) + + elif layer_name.endswith( + suffix(TRTLLMLayers.attention_dense_weight) + ) or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_weight)): + split_vals = torch.chunk(val, self.export_config.inference_tp_size, axis=0) + _add_to_trtllm_model_weights( + val=split_vals, layer_name=layer_name, split_type='tensor_split' + ) + + elif layer_name.endswith(suffix(TRTLLMLayers.mlp_fc_weight)) or layer_name.endswith( + suffix(TRTLLMLayers.mlp_fc_bias) + ): + 
split_gated_activation = self.activation in [ + "swiglu", + "geglu", + "fast-swiglu", + "fast-geglu", + ] + if split_gated_activation: + val, gate = torch.chunk(val, 2, axis=-1) + gate_layer_name = layer_name.replace("fc", "gate") + split_vals = torch.chunk(gate, self.export_config.inference_tp_size, axis=-1) + _add_to_trtllm_model_weights( + val=split_vals, layer_name=gate_layer_name, split_type='tensor_split' + ) + + split_vals = torch.chunk(val, self.export_config.inference_tp_size, axis=-1) + _add_to_trtllm_model_weights( + val=split_vals, layer_name=layer_name, split_type='tensor_split' + ) + + elif layer_name.endswith(suffix(TRTLLMLayers.attention_qkv_bias)): + qkv_hidden_dim = val.shape[0] + size_per_head = qkv_hidden_dim // ( + self.transformer_config.num_attention_heads + 2 * self.num_kv_heads + ) + q_num = self.transformer_config.num_attention_heads // self.num_kv_heads + + # We first concat all sub weights per tp rank together. + val = val.reshape(self.num_kv_heads, q_num + 2, size_per_head) + + qkv = torch.split(val, [q_num, 1, 1], dim=1) + q_split = torch.chunk(qkv[0], self.export_config.inference_tp_size, axis=0) + k_split = torch.chunk(qkv[1], self.export_config.inference_tp_size, axis=0) + v_split = torch.chunk(qkv[2], self.export_config.inference_tp_size, axis=0) + + # Concatenate Q, K, and V together + split_vals = [ + torch.concatenate( + [q_split[i].reshape(-1), k_split[i].reshape(-1), v_split[i].reshape(-1)], dim=0 + ) + for i in range(self.export_config.inference_tp_size) + ] + _add_to_trtllm_model_weights( + val=split_vals, layer_name=layer_name, split_type='tensor_split' + ) + + # TODO : Should add a atten layer dimension "qkvqkv, qqkkvv etc to see how to reshape here" + elif layer_name.endswith(suffix(TRTLLMLayers.attention_qkv_weight)): + hidden_dim = val.shape[0] + size_per_head = self.transformer_config.kv_channels + if size_per_head is None: + size_per_head = hidden_dim // self.transformer_config.num_attention_heads + q_num = self.transformer_config.num_attention_heads // self.num_kv_heads + + # When the merge factor exceeds 1, the 'vals' list will have multiple entries. + # Depending on the format, 'vals' can look like either [QQQQ..KV, QQQQ..KV, ...](for GQA) or [QKV, QKV, ...](for MHA). + # We first concat all sub weights per tp rank together. + val = val.reshape(hidden_dim, self.num_kv_heads, q_num + 2, size_per_head) + + # Split the QKV to separate variables. + qkv = torch.split(val, [q_num, 1, 1], dim=2) + + query_groups_shape = qkv[0].shape + if len(query_groups_shape) > 1: + if (query_groups_shape[1] % self.export_config.inference_tp_size) != 0: + raise Exception( + "Number of query groups of the models is {0}. 
Please select tensor parallelism size " + "that can split the number of query groups to equal number of query matrices in the " + "each GPU.".format(query_groups_shape[1]) + ) + + q_split = torch.chunk(qkv[0], self.export_config.inference_tp_size, axis=1) + k_split = torch.chunk(qkv[1], self.export_config.inference_tp_size, axis=1) + v_split = torch.chunk(qkv[2], self.export_config.inference_tp_size, axis=1) + + # Concatenate Q, K, and V together + split_vals = [ + torch.concatenate( + [ + q_split[i].reshape(hidden_dim, -1), + k_split[i].reshape(hidden_dim, -1), + v_split[i].reshape(hidden_dim, -1), + ], + dim=1, + ) + for i in range(self.export_config.inference_tp_size) + ] + _add_to_trtllm_model_weights( + val=split_vals, layer_name=layer_name, split_type='tensor_split' + ) + + elif layer_name.endswith(suffix(TRTLLMLayers.mlp_fc_weight_mixture_of_experts)): + w1, w3 = torch.chunk(val, 2, axis=1) + # w1 splits + split_w1s = torch.chunk(w1, self.export_config.inference_tp_size, axis=1) + # w3 splits + split_w3s = torch.chunk(w3, self.export_config.inference_tp_size, axis=1) + + split_vals = [torch.concatenate(item, dim=1) for item in zip(split_w3s, split_w1s)] + layer_name = layer_name.replace(".expert", "") # Remove suffix .expert from key + _add_to_trtllm_model_weights( + val=split_vals, layer_name=layer_name, split_type='expert_split' + ) + + elif layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_weight_mixture_of_experts)): + split_vals = torch.chunk(val, self.export_config.inference_tp_size, axis=-1) + layer_name = layer_name.replace(".expert", "") # Remove suffix .expert from key + _add_to_trtllm_model_weights( + val=split_vals, layer_name=layer_name, split_type='expert_split' + ) + else: + raise ValueError(f"{layer_name} cannot be handled by converter") + + @torch.no_grad() + def convert( + self, model_state_dict: dict, trtllm_conversion_dict, state_dict_split_by_layer_numbers=True + ): + """Convert model weights to trtllm model weights + + This method goes through each layer in the model state dict and converts to equivalent trtllm model weights. It also handles splitting across TP dimension , expert split etc. + + Args: + model_state_dict (dict): The full model state dict (all on CPU) + trtllm_conversion_dict (dict): The conversion dictionary used to convert model layer names to trtllm layer names + state_dict_split_by_layer_numbers (bool, optional): Are the model layers split by layer numbers in state dict. For example : mlp.fc1.weight can be represented like mlp.fc1.weight of shape [num_layers, hidden_dim, ffn_hidden_dim]} or it can be like mlp.fc1.layers.0.weight of shape [hidden_dim, ffn_hidden_dim], then mlp.fc1.layers.1.weight ... for all layers. If you use represenation 2 set this to True. 
Defaults to True + """ + + # First step is to convert input model layer names to equivalent trtllm layer names + model_state_dict = TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names( + model_state_dict=model_state_dict, + trtllm_conversion_dict=trtllm_conversion_dict, + state_dict_split_by_layer_numbers=state_dict_split_by_layer_numbers, + ) + + # Convert the non transformer layers + for layer_name in NON_TRANSFORMER_LAYERS_NAMES: + # For vocab embedding layer alone we pad the weights to be divisible by inference tp size + if ( + layer_name == TRTLLMLayers.vocab_embedding.value + and self.export_config.use_parallel_embedding + ): + val = model_state_dict[TRTLLMLayers.vocab_embedding.value] + vocab_size = val.shape[0] + if vocab_size % self.export_config.inference_tp_size != 0: + vocab_size_padded = pad_vocab_size( + vocab_size, self.export_config.inference_tp_size + ) + pad_width = vocab_size_padded - vocab_size + val = torch.nn.functional.pad(val, (0, 0, 0, pad_width), value=0) + model_state_dict[layer_name] = val + if layer_name == TRTLLMLayers.final_layernorm_weight.value: + # Same as layernorm1p in NeMo + if ( + self.transformer_config.layernorm_zero_centered_gamma + and self.transformer_config.normalization == "LayerNorm" + ): + model_state_dict[layer_name] = model_state_dict[layer_name] + 1.0 + + self._convert_non_transformer_layer( + model_state_dict=model_state_dict, layer_name=layer_name + ) + + transformer_layers_dict = {} + # Convert the transformer layers + if state_dict_split_by_layer_numbers: + # Already model dict is split by layer numbers + transformer_layers_dict = model_state_dict + else: + # Here we split the model state dict into individual layers + for layer_name in list(model_state_dict.keys()): + value = model_state_dict.pop(layer_name) + for layer_number in range(self.transformer_config.num_layers): + # e.g transformer.layers.mlp.fc.bias => transformer.layers.2.mlp.fc.bias + layer_name_with_layer_number = re.sub( + r'(?<=layers\.)', f'{layer_number}.', layer_name + ) + transformer_layers_dict[layer_name_with_layer_number] = value[layer_number] + + for layer_name, value in tqdm( + transformer_layers_dict.items(), desc="Converting to TRTLLM Weights" + ): + self._convert_transformer_layer(layer_name, value) + + def get_padded_vocab_size(self) -> int: + """Return the paded vocab size + + We extract the lm head and vocab embedding and use that to determine padded_vocab_size + + Returns: + int: Padded vocab size + """ + lm_head_weight = self.trtllm_model_weights.get(TRTLLMLayers.lm_head.value, None) + vocab_size = self.trtllm_model_weights[TRTLLMLayers.vocab_embedding.value].shape[0] + vocab_size_padded = ( + vocab_size + if lm_head_weight is None + else pad_vocab_size(vocab_size, self.export_config.inference_tp_size) + ) + return vocab_size_padded + + def get_local_model_weights_per_gpu(self, mapping, trtllm_model_config: dict): + """Get the trtllm model weights split per gpu + + Given the trtllm mapping information (tp, pp rank etc) we split the model weights in a list, with each element of the list corresponding to the weights of each gpu rank + + Args: + mapping : The trtllm mapping information + trtllm_model_config (dict): The trtllm model config + """ + + def _split(torch_tensor, tp_size, idx, dim=0): + """Splits the np tensor v on dim and return the idx's slice.""" + if tp_size == 1: + return torch_tensor + if len(torch_tensor.shape) == 1: + return torch.chunk(torch_tensor, tp_size)[idx].contiguous() + else: + return torch.chunk(torch_tensor, tp_size, 
axis=dim)[idx].contiguous() + + pp_layer_range = mapping.pp_layers(self.transformer_config.num_layers) + + trtllm_model_weights_per_gpu = {} + for layer_name, value in self.trtllm_model_weights.items(): + if layer_name in NON_TRANSFORMER_LAYERS_NAMES: + continue + + # Happens in the case of TP split or expert split + if layer_name.endswith(".bin"): + if layer_name.endswith(f"{mapping.tp_rank}.bin"): + layer_name = layer_name.replace(f".{mapping.tp_rank}.bin", "") + else: + continue + + layer_num = int(layer_name.split(".")[2]) + if layer_num in pp_layer_range: + layer_name = layer_name.replace( + f"layers.{layer_num}", f"layers.{layer_num - pp_layer_range[0]}" + ) + else: + continue + if ( + hasattr(trtllm_model_config, 'new_decoder_architecture') + and trtllm_model_config.new_decoder_architecture + and "post_layernorm" in layer_name + ): + layer_name = layer_name.replace("post_layernorm", "mlp_layernorm") + + trtllm_model_weights_per_gpu[layer_name] = value + + if mapping.is_first_pp_rank(): + embedding_weight = ( + _split( + self.trtllm_model_weights[TRTLLMLayers.vocab_embedding.value], + mapping.tp_size, + mapping.tp_rank, + ) + if self.export_config.use_parallel_embedding + else self.trtllm_model_weights[TRTLLMLayers.vocab_embedding.value] + ) + + trtllm_model_weights_per_gpu[TRTLLMLayers.vocab_embedding.value] = embedding_weight + + pos_embedding_weight = self.trtllm_model_weights.get( + TRTLLMLayers.position_embedding.value + ) + if pos_embedding_weight is not None: + if self.export_config.use_parallel_embedding: + pos_embedding_weight = _split( + pos_embedding_weight, mapping.tp_size, mapping.tp_rank + ) + + trtllm_model_weights_per_gpu[TRTLLMLayers.position_embedding.value] = ( + pos_embedding_weight + ) + + if mapping.is_last_pp_rank(): + lm_head_weight = self.trtllm_model_weights.get(TRTLLMLayers.lm_head.value, None) + if lm_head_weight is not None: + trtllm_model_weights_per_gpu[TRTLLMLayers.lm_head.value] = _split( + lm_head_weight, mapping.tp_size, mapping.tp_rank + ) + + trtllm_model_weights_per_gpu[TRTLLMLayers.final_layernorm_weight.value] = ( + self.trtllm_model_weights[TRTLLMLayers.final_layernorm_weight.value] + ) + + ln_f_bias = self.trtllm_model_weights.get(TRTLLMLayers.final_layernorm_bias.value) + if ln_f_bias is not None: + trtllm_model_weights_per_gpu[TRTLLMLayers.final_layernorm_bias.value] = ln_f_bias + + return trtllm_model_weights_per_gpu diff --git a/megatron/data/__init__.py b/megatron/core/extensions/__init__.py similarity index 100% rename from megatron/data/__init__.py rename to megatron/core/extensions/__init__.py diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py new file mode 100644 index 0000000000..8a27b5bcd4 --- /dev/null +++ b/megatron/core/extensions/transformer_engine.py @@ -0,0 +1,1353 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
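After `convert()` has populated `trtllm_model_weights`, each rank's slice is pulled out with a `tensorrt_llm.Mapping`; a hedged sketch where `converter` is assumed to be a `SingleDeviceTRTLLMModelWeightsConverter` that has already run and `trtllm_model_config` the matching pretrained config.

```python
import tensorrt_llm

mapping = tensorrt_llm.Mapping(world_size=2, rank=0, tp_size=2, pp_size=1)
rank0_weights = converter.get_local_model_weights_per_gpu(mapping, trtllm_model_config)
# Keys are per-rank: the '.{tp_rank}.bin' suffix is stripped and transformer
# layer indices are renumbered to be local to this PP rank.
```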
+ +import dataclasses +import io +import os +import pickle +import warnings +from typing import Any, Callable, Optional + +import torch +import transformer_engine as te +from packaging.version import Version as PkgVersion +from torch import Tensor +from torch.nn.parameter import Parameter + +from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding +from megatron.core.model_parallel_config import ModelParallelConfig +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.parallel_state import ( + get_context_parallel_global_ranks, + get_context_parallel_group, + get_expert_data_parallel_rank, + get_expert_model_parallel_rank, + get_expert_model_parallel_world_size, + get_expert_tensor_parallel_group, + get_expert_tensor_parallel_rank, + get_expert_tensor_parallel_world_size, + get_hierarchical_context_parallel_groups, + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name +from megatron.core.tensor_parallel.layers import ( + _initialize_affine_weight_cpu, + set_tensor_model_parallel_attributes, +) +from megatron.core.tensor_parallel.random import get_data_parallel_rng_tracker_name +from megatron.core.tensor_parallel.utils import divide +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint +from megatron.core.utils import get_te_version, is_te_min_version + + +def _get_extra_te_kwargs(config: TransformerConfig): + extra_transformer_engine_kwargs = {"params_dtype": config.params_dtype} + + if is_te_min_version("0.12.0"): + if config.use_cpu_initialization: + extra_transformer_engine_kwargs["device"] = 'cpu' + else: + extra_transformer_engine_kwargs["device"] = torch.cuda.current_device() + return extra_transformer_engine_kwargs + + +def condition_init_method(config, init_method): + """Condition TE init_method on config.perform_initialization.""" + return init_method if config.perform_initialization else (lambda w: None) + + +class TENorm: + """ + A conditional wrapper to initialize an instance of Transformer-Engine's + `LayerNorm` or `RMSNorm` based on input + """ + + # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm? + def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5): + if config.normalization == "LayerNorm": + instance = te.pytorch.LayerNorm( + hidden_size=hidden_size, + eps=eps, + sequence_parallel=config.sequence_parallel, + zero_centered_gamma=config.layernorm_zero_centered_gamma, + **_get_extra_te_kwargs(config), + ) + elif config.normalization == "RMSNorm": + assert hasattr( + te.pytorch, "RMSNorm" + ), "Transformer-Engine >= v0.11 required to use this feature" + instance = te.pytorch.RMSNorm( + hidden_size=hidden_size, + eps=eps, + sequence_parallel=config.sequence_parallel, + zero_centered_gamma=config.layernorm_zero_centered_gamma, + **_get_extra_te_kwargs(config), + ) + else: + raise Exception('Only LayerNorm and RMSNorm are curently supported') + + return instance + + +class TELinear(te.pytorch.Linear): + """ + Wrapper for the Transformer-Engine's `Linear` layer. + + Note that if Megatron's parallel_state has not been initialized + yet, the tp_group passed to TE will be None and must be set later + via set_tensor_parallel_group(). 
+ + parallel_mode currently supports 3 different values: + - "column": Split the weight matrix along output dimension (used in TEColumnParallelLinear) + - "row": Split the weight matrix along input dimension (used in TERowParallelLinear) + - "duplicated": No tensor parallelism and weight is duplicated across TP ranks + - Note: For expert linear layers, we will disable communication logic here + as TP communication is handled in token_dispatcher. + """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + parallel_mode: Optional[str], + config: ModelParallelConfig, + init_method: Callable, + bias: bool, + skip_bias_add: bool, + skip_weight_param_allocation: bool, + tp_comm_buffer_name: Optional[str] = None, + is_expert: bool = False, + ): + self.config = config + + # TE returns a zero length Tensor when bias=False and + # return_bias=True, but we prefer None. So in that case we + # tell TE to not return the bias, and return None + # ourselves. This way our forward always returns two values + # and we don't have to deal with the zero length Tensor. + self.te_return_bias = skip_bias_add and bias + self.is_first_microbatch = True + self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache + if skip_weight_param_allocation: + raise ValueError( + 'Transformer Engine linear layers do not support skip_weight_param_allocation' + ) + + extra_kwargs = _get_extra_te_kwargs(config) + + if is_te_min_version("0.8.0"): + if self.config.tp_comm_overlap: + if is_te_min_version("1.5.0"): + # Use old overlap flags if they were supplied instead + extra_kwargs["ub_overlap_ag"] = ( + self.config.tp_comm_overlap_ag + if hasattr(self.config, "tp_comm_overlap_ag") + else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag + ) + extra_kwargs["ub_overlap_rs"] = ( + self.config.tp_comm_overlap_rs + if hasattr(self.config, "tp_comm_overlap_rs") + else self.config.tp_comm_split_rs or self.config.tp_comm_atomic_rs + ) + # Disable ub overlap for experts. + if is_expert: + extra_kwargs["ub_overlap_ag"] = False + extra_kwargs["ub_overlap_rs"] = False + else: + extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag + extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag + extra_kwargs["ub_split_rs"] = self.config.tp_comm_split_rs + extra_kwargs["ub_atomic_gemm_rs"] = self.config.tp_comm_atomic_rs + # Disable ub overlap for experts. + if is_expert: + extra_kwargs["ub_split_ag"] = False + extra_kwargs["ub_atomic_gemm_ag"] = False + extra_kwargs["ub_split_rs"] = False + extra_kwargs["ub_atomic_gemm_rs"] = False + if is_te_min_version("1.0.0", check_equality=False): + assert ( + tp_comm_buffer_name is not None + ), "Buffer name should be set to configure communication overlap settings" + extra_kwargs["ub_name"] = tp_comm_buffer_name + + self.expert_parallel = self.config.expert_model_parallel_size > 1 + if is_expert: + rng_tracker_name = get_expert_parallel_rng_tracker_name() + else: + if parallel_mode == "duplicated": + rng_tracker_name = get_data_parallel_rng_tracker_name() + else: + rng_tracker_name = None + if is_te_min_version("1.7.0"): + extra_kwargs["rng_tracker_name"] = rng_tracker_name + + te_parallel_mode = parallel_mode + if parallel_mode == "duplicated": + # Handle non-parallel case + tp_group = None + tp_size = 1 + explicit_expert_comm = False + te_parallel_mode = None + else: + # Disable communications in TE when using TP or EP by + # making TE agnostic of model parallel. 
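+            # Note: when explicit expert communication is used below, the weight
+            # shapes are pre-divided and TE is constructed with tp_size=1 and
+            # tp_group=None, so TE itself issues no TP collectives; the MoE token
+            # dispatcher performs that communication instead.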
+ if is_expert: + tp_group = get_expert_tensor_parallel_group(check_initialized=False) + tp_size = get_expert_tensor_parallel_world_size() + else: + tp_group = get_tensor_model_parallel_group(check_initialized=False) + tp_size = get_tensor_model_parallel_world_size() + explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel) + + if explicit_expert_comm: + if parallel_mode == "column": + output_size = divide(output_size, tp_size) + elif parallel_mode == "row": + input_size = divide(input_size, tp_size) + te_parallel_mode = None + tp_size = 1 + tp_group = None + + super().__init__( + in_features=input_size, + out_features=output_size, + sequence_parallel=self.config.sequence_parallel, + fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, + tp_group=tp_group, + tp_size=tp_size, + get_rng_state_tracker=( + get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None + ), + init_method=condition_init_method(config, init_method), + bias=bias, + return_bias=self.te_return_bias, + parallel_mode=te_parallel_mode, + **extra_kwargs, + ) + + for param in self.parameters(): + if is_expert: + # Reduce the gradient on the expert_data_parallel group for expert linear layers + setattr(param, 'allreduce', not self.expert_parallel) + else: + # Reduce the gradient on DP group + setattr(param, 'allreduce', True) + if parallel_mode == "duplicated": + # Reduce the gradient further on the TP group since the weight is + # duplicated across TP ranks + setattr(param, 'sequence_parallel', self.config.sequence_parallel) + + def forward(self, x): + """Forward.""" + _is_first_microbatch = ( + None if self.disable_parameter_transpose_cache else self.is_first_microbatch + ) + out = super().forward(x, is_first_microbatch=_is_first_microbatch) + self.is_first_microbatch = False + + # TE only returns a tuple when return_bias is True, otherwise + # it returns a single Tensor, we always want to return two + # values regardless of the arguments. + if self.te_return_bias: + return out + return out, None + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Replicate cross TP/DP.""" + + # Provide the dist-ckpt support when TELinear is directly used + # It can only happen with duplicated parallel mode + assert ( + self.parallel_mode == None + ), "TELinear sharded_state_dict can only be used with duplicated parallel mode" + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint(state_dict, prefix, None, sharded_offsets) + + +class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear): + """ + Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines + layernorm and linear layers + """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + config: TransformerConfig, + init_method: Callable, + gather_output: bool, + bias: bool, + skip_bias_add: bool, + is_expert: bool, + skip_weight_param_allocation: bool = False, + tp_comm_buffer_name: Optional[str] = None, + ): + self.config = config + + if gather_output: + raise ValueError('Transformer Engine linear layers do not support gather_output = True') + + if is_expert: + raise ValueError('Transformer Engine linear layers do not yet support MoE') + + if skip_weight_param_allocation: + raise ValueError( + 'Transformer Engine linear layers do not support skip_weight_param_allocation' + ) + + # TE returns a zero length Tensor when bias=False and + # return_bias=True, but we prefer None. 
So in that case we + # tell TE to not return the bias, and return None + # ourselves. This way our forward always returns two values + # and we don't have to deal with the zero length Tensor. + self.te_return_bias = skip_bias_add and bias + self.is_first_microbatch = True + self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache + extra_kwargs = _get_extra_te_kwargs(config) + + # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm` + if is_te_min_version("0.11.0"): + extra_kwargs["normalization"] = self.config.normalization + elif self.config.normalization != "LayerNorm": + te_version = get_te_version() + raise ValueError( + f"Transformer Engine v{te_version} does not support {self.config.normalization}." + ) + + if is_te_min_version("0.8.0"): + if self.config.tp_comm_overlap: + extra_kwargs["ub_bulk_wgrad"] = self.config.tp_comm_bulk_wgrad + extra_kwargs["ub_bulk_dgrad"] = self.config.tp_comm_bulk_dgrad + if is_te_min_version("1.5.0", check_equality=False): + # Use old overlap flags if they were supplied instead + extra_kwargs["ub_overlap_ag"] = ( + self.config.tp_comm_overlap_ag + if hasattr(self.config, "tp_comm_overlap_ag") + else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag + ) + if is_te_min_version("1.6.0.dev0", check_equality=False): + extra_kwargs["ub_overlap_rs_dgrad"] = ( + self.config.tp_comm_overlap_rs_dgrad + if hasattr(self.config, "tp_comm_overlap_rs_dgrad") + else False + ) + if tp_comm_buffer_name == 'qkv' and self.config.tp_comm_overlap_disable_qkv: + extra_kwargs["ub_overlap_ag"] = False + extra_kwargs["ub_overlap_rs_dgrad"] = False + + if tp_comm_buffer_name == 'fc1' and self.config.tp_comm_overlap_disable_fc1: + extra_kwargs["ub_overlap_ag"] = False + extra_kwargs["ub_overlap_rs_dgrad"] = False + else: + extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag + extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag + if is_te_min_version("1.0.0", check_equality=False): + assert ( + tp_comm_buffer_name is not None + ), "Buffer name should be set to configure communication overlap settings" + extra_kwargs["ub_name"] = tp_comm_buffer_name + + super().__init__( + in_features=input_size, + out_features=output_size, + eps=self.config.layernorm_epsilon, + sequence_parallel=self.config.sequence_parallel, + fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, + tp_group=get_tensor_model_parallel_group(check_initialized=False), + tp_size=self.config.tensor_model_parallel_size, + get_rng_state_tracker=( + get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None + ), + init_method=( + condition_init_method(config, init_method) + if not config.use_cpu_initialization + else lambda w: None + ), + bias=bias, + return_bias=self.te_return_bias, + parallel_mode="column", + return_layernorm_output=False, + zero_centered_gamma=self.config.layernorm_zero_centered_gamma, + **extra_kwargs, + ) + + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() + + if config.use_cpu_initialization: + output_size_per_partition = divide(output_size, world_size) + _ = _initialize_affine_weight_cpu( + self.weight, + output_size, + input_size, + output_size_per_partition, + 0, + init_method=condition_init_method(config, init_method), + stride=1, + return_master_weight=False, + rank=rank, + world_size=world_size, + skip_set_tensor_parallel_attributes=True, + ) + if bias: + self.bias = Parameter( + torch.empty(output_size_per_partition, dtype=config.params_dtype) + ) + 
set_tensor_model_parallel_attributes(self.bias, True, 0, 1) + with torch.no_grad(): + self.bias.zero_() + setattr(self.bias, 'allreduce', True) + + def forward(self, x): + """Forward.""" + _is_first_microbatch = ( + None if self.disable_parameter_transpose_cache else self.is_first_microbatch + ) + out = super().forward(x, is_first_microbatch=_is_first_microbatch) + self.is_first_microbatch = False + + # TE only returns a tuple when return_bias is True, otherwise + # it returns a single Tensor, we always want to return two + # values regardless of the arguments. + if self.te_return_bias: + return out + return out, None + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Sharding along axis 0, bias sharded""" + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint( + state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets + ) + + def __repr__(self): + return ( + f"{type(self).__name__}(in_features={self.in_features}, " + f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})" + ) + + +class TEColumnParallelLinear(TELinear): + """ + Wrapper for the Transformer-Engine's `Linear` layer but specialized similar + to megatron's `ColumnParallelLinear` layer. + """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + config: ModelParallelConfig, + init_method: Callable, + gather_output: bool, + bias: bool, + skip_bias_add: bool, + is_expert: bool, + skip_weight_param_allocation: bool = False, + tp_comm_buffer_name: Optional[str] = None, + ): + if gather_output: + raise ValueError('Transformer Engine linear layers do not support gather_output = True') + + super().__init__( + input_size=input_size, + output_size=output_size, + parallel_mode="column", + config=config, + init_method=( + condition_init_method(config, init_method) + if not config.use_cpu_initialization + else lambda w: None + ), + bias=bias, + skip_bias_add=skip_bias_add, + is_expert=is_expert, + skip_weight_param_allocation=skip_weight_param_allocation, + tp_comm_buffer_name=tp_comm_buffer_name, + ) + + if config.use_cpu_initialization: + if is_expert: + world_size = get_expert_tensor_parallel_world_size() + rank = get_expert_tensor_parallel_rank() + else: + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() + output_size_per_partition = divide(output_size, world_size) + _ = _initialize_affine_weight_cpu( + self.weight, + output_size, + input_size, + output_size_per_partition, + 0, + init_method=condition_init_method(config, init_method), + stride=1, + return_master_weight=False, + rank=rank, + world_size=world_size, + skip_set_tensor_parallel_attributes=True, + ) + if bias: + self.bias = Parameter( + torch.empty(output_size_per_partition, dtype=config.params_dtype) + ) + set_tensor_model_parallel_attributes(self.bias, True, 0, 1) + with torch.no_grad(): + self.bias.zero_() + setattr(self.bias, 'allreduce', True) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Sharding along axis 0, bias sharded""" + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint( + state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets + ) + + def __repr__(self): + return ( + f"{type(self).__name__}(in_features={self.in_features}, " + f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})" + ) + + +class TERowParallelLinear(TELinear): + """ + Wrapper for the Transformer-Engine's 
`Linear` layer but specialized similar + to megatron's `RowParallelLinear` layer. + """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + config: ModelParallelConfig, + init_method: Callable, + bias: bool, + input_is_parallel: bool, + skip_bias_add: bool, + is_expert: bool, + tp_comm_buffer_name: Optional[str] = None, + ): + if not input_is_parallel: + raise ValueError( + "Transformer Engine linear layers do not support input_is_parallel = False" + ) + + super().__init__( + input_size=input_size, + output_size=output_size, + parallel_mode="row", + config=config, + init_method=( + condition_init_method(config, init_method) + if not config.use_cpu_initialization + else lambda w: None + ), + bias=bias, + skip_bias_add=skip_bias_add, + skip_weight_param_allocation=False, # We don't currently use this for row parallel layers # pylint: disable=line-too-long + is_expert=is_expert, + tp_comm_buffer_name=tp_comm_buffer_name, + ) + if config.use_cpu_initialization: + if is_expert: + world_size = get_expert_tensor_parallel_world_size() + rank = get_expert_tensor_parallel_rank() + else: + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() + input_size_per_partition = divide(input_size, world_size) + self.master_weight = _initialize_affine_weight_cpu( + self.weight, + output_size, + input_size, + input_size_per_partition, + 1, + init_method=condition_init_method(config, init_method), + stride=1, + return_master_weight=False, + params_dtype=config.params_dtype, + rank=rank, + world_size=world_size, + skip_set_tensor_parallel_attributes=True, + ) + if bias: + self.bias = Parameter(torch.empty(output_size, dtype=config.params_dtype)) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + setattr(self.bias, 'allreduce', True) + setattr(self.bias, 'sequence_parallel', config.sequence_parallel) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Sharding along axis 1, bias not sharded""" + state_dict = self.state_dict(prefix='', keep_vars=True) + return make_sharded_tensors_for_checkpoint( + state_dict, prefix, {'weight': 1}, sharded_offsets + ) + + def __repr__(self): + return ( + f"{type(self).__name__}(in_features={self.in_features}, " + f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})" + ) + + +class TEDotProductAttention(te.pytorch.DotProductAttention): + """ + Wrapper for the Transformer-Engine's `DotProductAttention` layer that also + has "flash attention" enabled. + + Note that if Megatron's parallel_state has not been initialized yet, the + tp_group and cp_group passed to TE will be None and must be set later + via set_tensor_parallel_group() and set_context_parallel_group(). + """ + + cp_stream: torch.cuda.Stream = None + + def __init__( + self, + config: TransformerConfig, + layer_number: int, + attn_mask_type: AttnMaskType, + attention_type: str, + attention_dropout: Optional[float] = None, + softmax_scale: Optional[float] = None, + k_channels: Optional[int] = None, + v_channels: Optional[int] = None, + cp_comm_type: str = "p2p", + ): + self.config = config + self.te_forward_mask_type = False + self.qkv_format: str = 'sbhd' + + if self.config.apply_query_key_layer_scaling != bool( + int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0')) + ): + raise ValueError( + f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} " + f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is " + f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. 
Transformer Engine does not support " + f"setting query key layer scaling via argument, so these two must match." + ) + + extra_kwargs: dict[str, Any] = {} + if is_te_min_version("0.11.0"): + extra_kwargs["num_gqa_groups"] = self.config.num_query_groups + elif self.config.num_query_groups != self.config.num_attention_heads: + raise ValueError( + f"Transformer Engine v{get_te_version()} does not support Grouped Query Attention, " + f"use a newer version of Transformer Engine. " + f"(num_query_groups ({self.config.num_query_groups}) != " + f"num_attention_heads ({self.config.num_attention_heads}))" + ) + + if is_te_min_version("0.10.0"): + extra_kwargs["attention_type"] = attention_type + # older version don't need attention_type + + if is_te_min_version("0.12.0", check_equality=False): + self.te_forward_mask_type = True + + # This check is important as CP config can be disabled while having a valid CP group + # Example - Disabling CP for encoder while a valid CP group exists for decoder + if self.config.context_parallel_size > 1: + assert is_te_min_version( + "1.0.0" + ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!" + if getattr(TEDotProductAttention, "cp_stream") is None: + TEDotProductAttention.cp_stream = torch.cuda.Stream() + extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False) + extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks( + check_initialized=False + ) + extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream + if is_te_min_version("1.10.0"): + if cp_comm_type is None: + extra_kwargs["cp_comm_type"] = "p2p" + elif cp_comm_type == "a2a+p2p": + assert is_te_min_version("1.12.0"), ( + f"Transformer-Engine v{get_te_version()} must be >= 1.12.0 to support" + "hierarchical cp commucation." + ) + extra_kwargs["cp_comm_type"] = "a2a+p2p" + extra_kwargs["cp_group"] = get_hierarchical_context_parallel_groups( + check_initialized=False + ) + else: + extra_kwargs["cp_comm_type"] = cp_comm_type + + if self.config.deterministic_mode: + if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0: + raise RuntimeError( + "deterministic_mode is on and we are using DotProductAttention from " + "Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. " + f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}." + ) + + if config.window_size is not None: + # Check version + assert is_te_min_version("1.2.0"), ( + f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support" + "sliding window attention." 
+ ) + extra_kwargs['window_size'] = config.window_size + + if is_te_min_version("1.10.0"): + # TE 1.10.0 introduces the ability to set the different k and v channels + kv_channels = ( + (k_channels, v_channels) + if k_channels is not None and v_channels is not None + else self.config.kv_channels + ) + extra_kwargs['softmax_scale'] = softmax_scale + else: + kv_channels = self.config.kv_channels + + self.kept_packed_seq_params = set( + field.name for field in dataclasses.fields(PackedSeqParams) + ) + if get_te_version() < PkgVersion("1.3.0"): + # TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H + # copies (#555) + # These two arguments did not exist prior to 1.3.0 + self.kept_packed_seq_params.discard("max_seqlen_q") + self.kept_packed_seq_params.discard("max_seqlen_kv") + + if get_te_version() < PkgVersion("1.10.0"): + # TE 1.8.0 introduces cu_seqlens_padded which is the cu_seqlens with paddings counted + # in each individual sequence in THD format dataset + # These two arguments did not exist prior to 1.8.0. Full support added in 1.10.0 (#1012) + self.kept_packed_seq_params.discard("cu_seqlens_q_padded") + self.kept_packed_seq_params.discard("cu_seqlens_kv_padded") + + super().__init__( + num_attention_heads=self.config.num_attention_heads, + kv_channels=kv_channels, + attention_dropout=( + self.config.attention_dropout if attention_dropout is None else attention_dropout + ), + attn_mask_type=attn_mask_type.name, + sequence_parallel=self.config.sequence_parallel, + tp_size=self.config.tensor_model_parallel_size, + get_rng_state_tracker=( + get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None + ), + tp_group=get_tensor_model_parallel_group(check_initialized=False), + layer_number=layer_number, + **extra_kwargs, + ) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attention_mask: Tensor, + attn_mask_type: AttnMaskType, + attention_bias: Tensor = None, + packed_seq_params: PackedSeqParams = None, + ): + """Forward.""" + packed_seq_kwargs = ( + {key: getattr(packed_seq_params, key) for key in self.kept_packed_seq_params} + if packed_seq_params is not None + else {} + ) + # overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set + # after init + if self.config.apply_rope_fusion and is_te_min_version("0.13.0", check_equality=False): + self.qkv_format = 'bshd' + + qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format) + + # WAR for peak memory usage. + # See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2388 + if self.config.apply_rope_fusion and qkv_format == 'bshd': + query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)] + # In PyTorch, the following two tensors are in fact the same: + # Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1) + # Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1) + # Stride for a dimension that is 1 has no meaning, so tensors created two different ways + # can have same shape but different strides. + # We unify them to the first one to pass the stride check in TE + if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride(): + value = value.as_strided(value.shape, key.stride()) + + attention_bias_kwargs = {} + if attention_bias is not None: + assert is_te_min_version("1.2.0"), ( + f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support" + "`attention_bias`." 
+ ) + attention_bias_kwargs = dict( + core_attention_bias_type='post_scale_bias', core_attention_bias=attention_bias + ) + + if self.te_forward_mask_type: + if qkv_format == 'thd' and is_te_min_version("1.7.0"): + # thd format uses flash attention with cuDNN kernel which requires is_padding=True, + # so the only acceptable mask types are `padding_causal` and `padding`. These do not + # necessarily indicate there are padded tokens in the sequence. + if attn_mask_type == AttnMaskType.causal: + attn_mask_type = AttnMaskType.padding_causal + elif attn_mask_type == AttnMaskType.no_mask: + attn_mask_type = AttnMaskType.padding + core_attn_out = super().forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type.name, + **attention_bias_kwargs, + **packed_seq_kwargs, + ) + else: + core_attn_out = super().forward( + query, key, value, attention_mask, **attention_bias_kwargs, **packed_seq_kwargs + ) + + if self.config.apply_rope_fusion and qkv_format == 'bshd': + return core_attn_out.transpose(0, 1) + else: + return core_attn_out + + +if is_te_min_version("1.9.0.dev0"): + + class TEGroupedLinear(te.pytorch.GroupedLinear): + """ + Wrapper for the Transformer-Engine's `GroupedLinear` layer. + + Note that if Megatron's parallel_state has not been initialized + yet, the tp_group passed to TE will be None and must be set later + via set_tensor_parallel_group(). + """ + + def __init__( + self, + num_gemms: int, + input_size: int, + output_size: int, + *, + parallel_mode: Optional[str], + config: ModelParallelConfig, + init_method: Callable, + bias: bool, + skip_bias_add: bool, + is_expert: bool = False, + tp_comm_buffer_name: Optional[str] = None, + ): + self.config = config + + # TE returns a zero length Tensor when bias=False and + # return_bias=True, but we prefer None. So in that case we + # tell TE to not return the bias, and return None + # ourselves. This way our forward always returns two values + # and we don't have to deal with the zero length Tensor. + self.te_return_bias = skip_bias_add and bias + self.is_first_microbatch = True + self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache + + extra_kwargs = _get_extra_te_kwargs(config) + extra_kwargs["ub_name"] = tp_comm_buffer_name + + self.expert_parallel = self.config.expert_model_parallel_size > 1 + if is_expert: + extra_kwargs["rng_tracker_name"] = get_expert_parallel_rng_tracker_name() + + # The comms between TP and EP group is explicitly handled by MoE token dispatcher. + # So we disable comms by making TE agnostic of model parallel. 
+ if is_expert: + tp_group = get_expert_tensor_parallel_group(check_initialized=False) + tp_size = get_expert_tensor_parallel_world_size() + else: + tp_group = get_tensor_model_parallel_group(check_initialized=False) + tp_size = get_tensor_model_parallel_world_size() + self.explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel) + + if self.explicit_expert_comm: + if parallel_mode == "column": + output_size = divide(output_size, tp_size) + elif parallel_mode == "row": + input_size = divide(input_size, tp_size) + parallel_mode = None + tp_size = 1 + tp_group = None + + super().__init__( + num_gemms=num_gemms, + in_features=input_size, + out_features=output_size, + sequence_parallel=self.config.sequence_parallel, + fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, + tp_group=tp_group, + tp_size=tp_size, + get_rng_state_tracker=( + get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None + ), + init_method=condition_init_method(config, init_method), + bias=bias, + return_bias=self.te_return_bias, + parallel_mode=parallel_mode, + **extra_kwargs, + ) + + for param in self.parameters(): + setattr(param, 'allreduce', not (is_expert and self.expert_parallel)) + + def merge_extra_states( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """ + Merge multiple "_extra_state" into one. + """ + self.init_fp8_metadata(num_gemms=self.num_gemms) + fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration + + try: + state_list = [ + state_dict.pop(f"{prefix}_extra_state{i}") for i in range(1, self.num_gemms) + ] + except KeyError: + # "_extra_state{i}" only exists for dist-ckpt. Return for torch native ckpt. + return + + if not fp8_checkpoint: + return + state_list = [state_dict.pop(f"{prefix}_extra_state")] + state_list + state_list = [self._decode_extra_state(state) for state in state_list] + extra_fp8_variables = state_list[0]['extra_fp8_variables'] + extra_fp8_variables['num_gemms'] = self.num_gemms + extra_state = { + "scale_fwd": torch.cat( + [state['scale_fwd'].view(-1, 1) for state in state_list], dim=1 + ).view(-1), + "scale_inv_fwd": torch.cat( + [state['scale_inv_fwd'].view(-1, 1) for state in state_list], dim=1 + ).view(-1), + "amax_history_fwd": torch.cat( + [state['amax_history_fwd'].view(-1, 1) for state in state_list], dim=1 + ).view(self.fp8_meta["recipe"].amax_history_len, -1), + "scale_bwd": torch.cat( + [state['scale_bwd'].view(-1, 1) for state in state_list], dim=1 + ).view(-1), + "scale_inv_bwd": torch.cat( + [state['scale_inv_bwd'].view(-1, 1) for state in state_list], dim=1 + ).view(-1), + "amax_history_bwd": torch.cat( + [state['amax_history_bwd'].view(-1, 1) for state in state_list], dim=1 + ).view(self.fp8_meta["recipe"].amax_history_len, -1), + "extra_fp8_variables": extra_fp8_variables, + } + state_dict[f"{prefix}_extra_state"] = self._encode_extra_state(extra_state) + + self._register_load_state_dict_pre_hook(merge_extra_states, with_module=True) + + def forward(self, x, m_splits): + """Forward.""" + _is_first_microbatch = ( + None if self.disable_parameter_transpose_cache else self.is_first_microbatch + ) + out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) + self.is_first_microbatch = False + + # TE only returns a tuple when return_bias is True, otherwise + # it returns a single Tensor, we always want to return two + # values regardless of the arguments. 
+ if self.te_return_bias: + return out + return out, None + + def _encode_extra_state(self, state): + state_serialized = io.BytesIO() + torch.save(state, state_serialized) + return state_serialized + + def _decode_extra_state(self, state): + if isinstance(state, torch.Tensor): + return pickle.loads(state.detach().cpu().numpy().tobytes()) + elif isinstance(state, io.BytesIO): + state.seek(0) + return torch.load(state, map_location="cuda") + else: + raise RuntimeError("Unsupported checkpoint format.") + + def _split_extra_state(self, state): + fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration + + if not fp8_checkpoint: + return [state] * self.num_gemms + + state = self._decode_extra_state(state) + extra_states = [] + extra_fp8_variables = state['extra_fp8_variables'] + extra_fp8_variables['num_gemms'] = 1 + for gemm_idx in range(self.num_gemms): + tmp_state = { + "scale_fwd": state['scale_fwd'].view(3, -1)[:, gemm_idx], + "scale_inv_fwd": state['scale_inv_fwd'].view(3, -1)[:, gemm_idx], + "amax_history_fwd": state['amax_history_fwd'].view( + self.fp8_meta["recipe"].amax_history_len, 3, -1 + )[:, :, gemm_idx], + "scale_bwd": state['scale_bwd'].view(2, -1)[:, gemm_idx], + "scale_inv_bwd": state['scale_inv_bwd'].view(2, -1)[:, gemm_idx], + "amax_history_bwd": state['amax_history_bwd'].view( + self.fp8_meta["recipe"].amax_history_len, 2, -1 + )[:, :, gemm_idx], + "extra_fp8_variables": extra_fp8_variables, + } + extra_states.append(self._encode_extra_state(tmp_state)) + return extra_states + + def _sharded_state_dict_grouped( + self, tp_axis_map, prefix='', sharded_offsets=(), metadata=None + ): + """ + prefix should be module_name to make keys identical to sequetial ones. + """ + sharded_state_dict = {} + full_state_dict = self.state_dict(prefix='', keep_vars=True) + num_global_experts = get_expert_model_parallel_world_size() * self.num_gemms + local_expert_indices_offset = get_expert_model_parallel_rank() * self.num_gemms + ep_axis = len(sharded_offsets) + extra_states = self._split_extra_state(full_state_dict['_extra_state']) + for gemm_idx in range(self.num_gemms): + state_dict = { + f'{gemm_idx}.weight': full_state_dict[f'weight{gemm_idx}'], + f'{gemm_idx}._extra_state': extra_states[gemm_idx], + } + if self.use_bias: + state_dict[f'{gemm_idx}.bias'] = full_state_dict[f'bias{gemm_idx}'] + sub_sd = make_sharded_tensors_for_checkpoint( + state_dict, + '', + tp_axis_map, + ( + *sharded_offsets, + (ep_axis, local_expert_indices_offset + gemm_idx, num_global_experts), + ), + ) + # Remove expert layers indexing from sharded keys + replace_prefix_for_sharding(sub_sd, f'{gemm_idx}.', prefix) + sharded_state_dict.update( + { + f'{prefix}weight{gemm_idx}': sub_sd[f'{gemm_idx}.weight'], + f'{prefix}_extra_state{"" if gemm_idx == 0 else gemm_idx}': sub_sd[ + f'{gemm_idx}._extra_state' + ], + } + ) + if self.use_bias: + sharded_state_dict[f'{prefix}bias{gemm_idx}'] = sub_sd[f'{gemm_idx}.bias'] + # Adjust replica ids - replication along DP modulo EP + for k, sh_ten in sharded_state_dict.items(): + replica_id = sh_ten.replica_id + assert ( + len(replica_id) == 3 + ), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}' + sh_ten.replica_id = (*replica_id[:2], get_expert_data_parallel_rank()) + return sharded_state_dict + + class TEColumnParallelGroupedLinear(TEGroupedLinear): + """ + Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized + to column-parallel style. 
+ """ + + def __init__( + self, + num_gemms: int, + input_size: int, + output_size: int, + *, + config: ModelParallelConfig, + init_method: Callable, + bias: bool, + skip_bias_add: bool, + is_expert: bool, + tp_comm_buffer_name: Optional[str] = None, + ): + + super().__init__( + num_gemms=num_gemms, + input_size=input_size, + output_size=output_size, + parallel_mode="column", + config=config, + init_method=condition_init_method(config, init_method), + bias=bias, + skip_bias_add=skip_bias_add, + is_expert=is_expert, + tp_comm_buffer_name=tp_comm_buffer_name, + ) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """ + For each gemm, sharding along axis 0, bias sharded. + Assume sharded_offsets[-1] is the expert parallel offset. + """ + tp_axis_map = {} + for gemm_idx in range(self.num_gemms): + tp_axis_map.update({f'{gemm_idx}.weight': 0, f'{gemm_idx}.bias': 0}) + return super()._sharded_state_dict_grouped( + tp_axis_map, prefix, sharded_offsets, metadata + ) + + class TERowParallelGroupedLinear(TEGroupedLinear): + """ + Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized + to row-parallel style. + """ + + def __init__( + self, + num_gemms: int, + input_size: int, + output_size: int, + *, + config: ModelParallelConfig, + init_method: Callable, + bias: bool, + skip_bias_add: bool, + is_expert: bool, + tp_comm_buffer_name: Optional[str] = None, + ): + + super().__init__( + num_gemms=num_gemms, + input_size=input_size, + output_size=output_size, + parallel_mode="row", + config=config, + init_method=condition_init_method(config, init_method), + bias=bias, + skip_bias_add=skip_bias_add, + is_expert=is_expert, + tp_comm_buffer_name=tp_comm_buffer_name, + ) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """ + For each gemm, sharding along axis 1, bias not sharded. + Assume sharded_offsets[-1] is the expert parallel offset. + """ + tp_axis_map = {f'{gemm_idx}.weight': 1 for gemm_idx in range(self.num_gemms)} + return super()._sharded_state_dict_grouped( + tp_axis_map, prefix, sharded_offsets, metadata + ) + +else: + + TEGroupedLinear = None # type: ignore[assignment, misc] + TEColumnParallelGroupedLinear = None # type: ignore[assignment, misc] + TERowParallelGroupedLinear = None # type: ignore[assignment, misc] + + +class TEDelayedScaling(te.common.recipe.DelayedScaling): + """ + Wrapper for the Transformer-Engine's `DelayedScaling` layer. 
+ """ + + def __init__( + self, + config: ModelParallelConfig, + fp8_format: int, + override_linear_precision: tuple = (False, False, False), + ): + extra_kwargs = _get_extra_te_kwargs(config) + if is_te_min_version("1.6.0.dev0"): + extra_kwargs["fp8_dpa"] = config.fp8_dot_product_attention + extra_kwargs["fp8_mha"] = config.fp8_multi_head_attention + if get_te_version() < PkgVersion("1.8.0"): + extra_kwargs["interval"] = config.fp8_interval + elif config.fp8_interval != 1: + warnings.warn("fp8_interval is deprecated and ignored from Transformer-Engine v1.8.0.") + + super().__init__( + margin=config.fp8_margin, + fp8_format=fp8_format, + amax_compute_algo=config.fp8_amax_compute_algo, + amax_history_len=config.fp8_amax_history_len, + override_linear_precision=override_linear_precision, + **extra_kwargs, + ) + + +class TECudaRNGStatesTracker(te.pytorch.distributed.CudaRNGStatesTracker): + """Wraps TransformerEngine's CudaRNGStatesTracker so that it is + interchangeable with Megatron's RNG tracker""" + + def __init__(self): + super().__init__() + self.reset() + + def is_initialized(self): + """Checks if the internal RNG state has been set wirth set_states().""" + return self._is_initialized + + def reset(self): + """Reset the internal RNG state.""" + super().reset() + self._is_initialized = False + + def set_states(self, states): + """Set the internal RNG state.""" + super().set_states(states) + self._is_initialized = True + + def add(self, name, seed): + """Track the rng state.""" + super().add(name, seed) + self._is_initialized = True + + +def te_checkpoint( + forward_func, + distribute_saved_activations, + get_rng_state_tracker, + tp_group, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, +): + """Checkpointing with Transformer-Engine.""" + from transformer_engine.pytorch.distributed import checkpoint + + if is_te_min_version("1.5.0"): + return checkpoint( + forward_func, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + distribute_saved_activations=distribute_saved_activations, + get_rng_state_tracker=get_rng_state_tracker, + tp_group=tp_group, + ) + else: + return checkpoint( + forward_func, + distribute_saved_activations, + get_rng_state_tracker, + tp_group, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + ) + + +try: + + from transformer_engine.pytorch.attention import _SplitAlongDim + + SplitAlongDim = _SplitAlongDim.apply + +except ImportError: + + SplitAlongDim = None + +try: + + from transformer_engine.pytorch.cpu_offload import ( + get_cpu_offload_context as _get_cpu_offload_context, + ) + + def get_cpu_offload_context( + enabled, num_layers, model_layers, activation_offloading, weight_offloading + ): + """Get CPU offload context and sync function.""" + if is_te_min_version("1.10.0.dev0"): + context, sync_func = _get_cpu_offload_context( + enabled, num_layers, model_layers, activation_offloading, weight_offloading + ) + else: + context, sync_func = _get_cpu_offload_context( + enabled, num_layers, activation_offloading, weight_offloading + ) + + return context, sync_func + +except ImportError: + + get_cpu_offload_context = None # type: ignore[assignment, misc] + +try: + + from transformer_engine.pytorch.attention import FusedRoPEFunc + + def fused_apply_rotary_pos_emb( + t: torch.Tensor, freqs: torch.Tensor, transpose_output_memory: bool = False + ) -> torch.Tensor: + """Apply rotary positional embedding to input tensor T in `sbhd` format.""" + if transpose_output_memory: + 
warnings.warn( + "transpose_output_memory is not supported by TE's fused RoPE and will be ignored." + ) + return FusedRoPEFunc.apply(t, freqs, "sbhd") + + def fused_apply_rotary_pos_emb_thd( + t: torch.Tensor, + cu_seqlens: torch.Tensor, + freqs: torch.Tensor, + cp_size: int = 1, + cp_rank: int = 0, + ) -> torch.Tensor: + """ + Apply rotary positional embedding to input tensor T in `thd` format with CP support. + """ + if is_te_min_version("1.11.0", check_equality=False): + return FusedRoPEFunc.apply(t, freqs, "thd", cu_seqlens, cp_size, cp_rank) + else: + return FusedRoPEFunc.apply(t, freqs, "thd", cu_seqlens) + +except ImportError: + + pass + +try: + + from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding # pylint: disable=unused-import + +except ImportError: + + Fp8Padding = None + Fp8Unpadding = None + +try: + + from transformer_engine.pytorch.permutation import ( + moe_permute, + moe_sort_chunks_by_index, + moe_unpermute, + ) + + fused_permute = moe_permute + fused_unpermute = moe_unpermute + fused_sort_chunks_by_index = moe_sort_chunks_by_index + +except ImportError: + + fused_permute = None + fused_unpermute = None + fused_sort_chunks_by_index = None diff --git a/megatron/core/fusions/fused_bias_dropout.py b/megatron/core/fusions/fused_bias_dropout.py index 14c1fe0d71..c7fa8419a0 100644 --- a/megatron/core/fusions/fused_bias_dropout.py +++ b/megatron/core/fusions/fused_bias_dropout.py @@ -3,6 +3,8 @@ import torch +from megatron.core.jit import jit_fuser + def _bias_dropout_add_func(x_with_bias, residual, prob, training): # type: (Tuple[Tensor, Optional[Tensor]], Tensor, float, bool) -> Tensor @@ -43,16 +45,16 @@ def _bias_dropout_add(x_with_bias, residual, prob): return _bias_dropout_add -@torch.jit.script +@jit_fuser def bias_dropout_add_fused_train( - x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float, + x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float ) -> torch.Tensor: return _bias_dropout_add_func(x_with_bias, residual, prob, True) -@torch.jit.script +@jit_fuser def bias_dropout_add_fused_inference( - x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float, + x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float ) -> torch.Tensor: return _bias_dropout_add_func(x_with_bias, residual, prob, False) diff --git a/megatron/core/fusions/fused_bias_geglu.py b/megatron/core/fusions/fused_bias_geglu.py new file mode 100644 index 0000000000..70ef348828 --- /dev/null +++ b/megatron/core/fusions/fused_bias_geglu.py @@ -0,0 +1,85 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import torch + +from megatron.core.jit import jit_fuser + +###### BIAS GELU FUSION/ NO AUTOGRAD ################ +# 1/sqrt(2*pi)-> 0.3989423 +# 1/sqrt(2) -> 0.70710678 +# sqrt(2/pi) -> 0.79788456 +# this function is tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) + + +@jit_fuser +def geglu(y): + y_1, y_2 = torch.chunk(y, 2, -1) + return (y_1 * 0.5 * (1.0 + torch.tanh(0.79788456 * y_1 * (1 + 0.044715 * y_1 * y_1)))) * y_2 + + +@jit_fuser +def bias_geglu(bias, y): + y = y + bias + return geglu(y) + + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. 
+ torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@jit_fuser +def geglu_back(g, y): + y_1, y_2 = torch.chunk(y, 2, -1) + tanh_out = torch.tanh(0.79788456 * y_1 * (1 + 0.044715 * y_1 * y_1)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * y_1 * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * y_1 * y_1)) + 0.5 * ( + 1 + tanh_out + ) + return torch.cat(((g * y_2) * ff, g * (y_1 * 0.5 * (1.0 + tanh_out))), -1) + + +@jit_fuser +def bias_geglu_back(g, y, bias): + y = y + bias + return geglu_back(g, y) + + +class BiasGeGLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input, bias): + ctx.save_for_backward(input, bias) + return bias_geglu(input, bias) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + tmp = bias_geglu_back(grad_output, input, bias) + return tmp, tmp + + +class GeGLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input): + ctx.save_for_backward(input) + return geglu(input) + + @staticmethod + def backward(ctx, grad_output): + input = ctx.saved_tensors + tmp = geglu_back(grad_output, input[0]) + return tmp + + +def bias_geglu_impl(input, bias): + ori_shape = input.shape + assert len(ori_shape) in [2, 3] + input = input.view(-1, ori_shape[-1]) + if bias is not None: + output = BiasGeGLUFunction.apply(input, bias) + else: + output = GeGLUFunction.apply(input) + + return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1) diff --git a/megatron/core/fusions/fused_bias_gelu.py b/megatron/core/fusions/fused_bias_gelu.py index 9c791c1807..8cc90f6174 100644 --- a/megatron/core/fusions/fused_bias_gelu.py +++ b/megatron/core/fusions/fused_bias_gelu.py @@ -2,7 +2,9 @@ import torch -###### BIAS GELU FUSION/ NO AUTOGRAD ################ +from megatron.core.jit import jit_fuser + +# BIAS GELU FUSION/ NO AUTOGRAD ################ # 1/sqrt(2*pi)-> 0.3989423 # 1/sqrt(2) -> 0.70710678 # sqrt(2/pi) -> 0.79788456 @@ -11,7 +13,7 @@ # x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) -@torch.jit.script +@jit_fuser def bias_gelu(bias, y): x = bias + y return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) @@ -20,7 +22,7 @@ def bias_gelu(bias, y): # gradient of tanh approximation of gelu # gradient of actual gelu is: # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) -@torch.jit.script +@jit_fuser def bias_gelu_back(g, bias, y): x = bias + y tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) @@ -44,5 +46,10 @@ def backward(ctx, grad_output): tmp = bias_gelu_back(grad_output, bias, input) return tmp, tmp + # This is required to make Sphinx happy :-( + @classmethod + def apply(cls, *args, **kwargs): + return super().apply(*args, **kwargs) + bias_gelu_impl = GeLUFunction.apply diff --git a/megatron/core/fusions/fused_bias_swiglu.py b/megatron/core/fusions/fused_bias_swiglu.py new file mode 100644 index 0000000000..fd3ac3ec6f --- /dev/null +++ b/megatron/core/fusions/fused_bias_swiglu.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
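For orientation, the fused SwiGLU added in fused_bias_swiglu.py computes silu(y_1) * y_2 on the two halves of the last dimension, with a hand-written backward. A minimal unfused reference (a sketch only; the function name is local to this example) against which the fused path can be sanity-checked:

import torch
import torch.nn.functional as F

def reference_swiglu(y: torch.Tensor) -> torch.Tensor:
    # Unfused reference: split the last dimension in two and gate with SiLU.
    y_1, y_2 = torch.chunk(y, 2, dim=-1)
    return F.silu(y_1) * y_2

y = torch.randn(4, 8, requires_grad=True)
out = reference_swiglu(y)
out.sum().backward()            # autograd yields the gradient that swiglu_back computes in closed form
print(out.shape, y.grad.shape)  # torch.Size([4, 4]) torch.Size([4, 8])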
+ +import torch +import torch.nn.functional as F + +from megatron.core.jit import jit_fuser + +###### BIAS SWIGLU FUSION/ NO AUTOGRAD ################ + + +@jit_fuser +def swiglu(y): + y_1, y_2 = torch.chunk(y, 2, -1) + return F.silu(y_1) * y_2 + + +@jit_fuser +def bias_swiglu(y, bias): + y = y + bias + return swiglu(y) + + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@jit_fuser +def swiglu_back(g, y): + y_1, y_2 = torch.chunk(y, 2, -1) + return torch.cat( + (g * torch.sigmoid(y_1) * (1 + y_1 * (1 - torch.sigmoid(y_1))) * y_2, g * F.silu(y_1)), -1 + ) + + +@jit_fuser +def bias_swiglu_back(g, y, bias): + y = y + bias + return swiglu_back(g, y) + + +class BiasSwiGLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input, bias, fp8_input_store): + input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input + ctx.save_for_backward(input_for_backward, bias) + ctx.ori_input_dtype = input.dtype + ctx.fp8_input_store = fp8_input_store + return bias_swiglu(input, bias) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input + tmp = bias_swiglu_back(grad_output, input, bias) + return tmp, tmp, None + + +class SwiGLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input, fp8_input_store): + input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input + ctx.save_for_backward(input_for_backward) + ctx.ori_input_dtype = input.dtype + ctx.fp8_input_store = fp8_input_store + return swiglu(input) + + @staticmethod + def backward(ctx, grad_output): + input = ctx.saved_tensors[0] + input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input + tmp = swiglu_back(grad_output, input) + return tmp, None + + +def bias_swiglu_impl(input, bias, fp8_input_store=False): + ori_shape = input.shape + assert len(ori_shape) in [2, 3] + input = input.view(-1, ori_shape[-1]) + if bias is not None: + output = BiasSwiGLUFunction.apply(input, bias, fp8_input_store) + else: + output = SwiGLUFunction.apply(input, fp8_input_store) + + return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1) + + +# bias_swiglu_impl = BiasSwiGLUFunction.apply +# swiglu_impl = SwiGLUFunction.apply diff --git a/megatron/core/fusions/fused_cross_entropy.py b/megatron/core/fusions/fused_cross_entropy.py new file mode 100644 index 0000000000..909cc403cf --- /dev/null +++ b/megatron/core/fusions/fused_cross_entropy.py @@ -0,0 +1,143 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
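The fused vocab-parallel cross entropy defined in this file concatenates the predicted logits and the exp-sums into one tensor so that a single all-reduce covers both. As a single-process sanity sketch of the underlying identity (no tensor parallelism here; all names are local to the example):

import torch
import torch.nn.functional as F

logits = torch.randn(3, 11)                # [tokens, vocab]
target = torch.randint(0, 11, (3,))

# Cross entropy on max-shifted logits: loss = log(sum_exp_logits) - predicted_logit
logits_max = logits.max(dim=-1, keepdim=True).values
shifted = logits - logits_max
predicted = shifted[torch.arange(3), target]
sum_exp = shifted.exp().sum(dim=-1)
loss = torch.log(sum_exp) - predicted

print(torch.allclose(loss, F.cross_entropy(logits, target, reduction="none")))  # True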
+ +from typing import Tuple + +import torch + +from megatron.core.jit import jit_fuser +from megatron.core.parallel_state import ( + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from megatron.core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy +from megatron.core.tensor_parallel.utils import VocabUtility + + +@jit_fuser +def calculate_logits_max(vocab_parallel_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + + vocab_parallel_logits, logits_max = VocabParallelCrossEntropy.calculate_logits_max( + vocab_parallel_logits + ) + + return vocab_parallel_logits, logits_max + + +@jit_fuser +def calculate_predicted_logits( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + logits_max: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + + (target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits) = ( + VocabParallelCrossEntropy.calculate_predicted_logits( + vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index + ) + ) + + predicted_logits_sum_exp_logits = torch.cat((predicted_logits, sum_exp_logits)) + + return target_mask, masked_target_1d, predicted_logits_sum_exp_logits, exp_logits + + +@jit_fuser +def calculate_cross_entropy_loss( + exp_logits: torch.Tensor, predicted_logits_sum_exp_logits: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + + split_val = predicted_logits_sum_exp_logits.size()[0] // 2 + predicted_logits, sum_exp_logits = torch.split(predicted_logits_sum_exp_logits, split_val) + + exp_logits, loss = VocabParallelCrossEntropy.calculate_cross_entropy_loss( + exp_logits, predicted_logits, sum_exp_logits + ) + + return exp_logits, loss + + +@jit_fuser +def calculate_gradients( + softmax: torch.Tensor, + grad_output: torch.Tensor, + target_mask: torch.Tensor, + masked_target_1d: torch.Tensor, +) -> torch.Tensor: + + (grad_2d, arange_1d, softmax_update, grad_input) = ( + VocabParallelCrossEntropy.prepare_gradient_calculation_operands(softmax, target_mask) + ) + + grad_input = VocabParallelCrossEntropy.calculate_gradients( + grad_2d, arange_1d, masked_target_1d, softmax_update, grad_input, grad_output + ) + + grad_input = grad_input.to(torch.bfloat16) + + return grad_input + + +class _VocabParallelCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, vocab_parallel_logits, target): + + vocab_parallel_logits, logits_max = calculate_logits_max(vocab_parallel_logits) + torch.distributed.all_reduce( + logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group() + ) + + # Get the partition's vocab indices + get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size + partition_vocab_size = vocab_parallel_logits.size()[-1] + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size) + + (target_mask, masked_target_1d, predicted_logits_sum_exp_logits, exp_logits) = ( + calculate_predicted_logits( + vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index + ) + ) + + # All reduce is needed to get the chunks from other GPUs. 
+ # In the fused case, tensors are batches to invoke a single + # AllReduce call + torch.distributed.all_reduce( + predicted_logits_sum_exp_logits, + op=torch.distributed.ReduceOp.SUM, + group=get_tensor_model_parallel_group(), + ) + + exp_logits, loss = calculate_cross_entropy_loss(exp_logits, predicted_logits_sum_exp_logits) + + # Store softmax, target-mask and masked-target for backward pass. + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + + return loss + + @staticmethod + def backward(ctx, grad_output): + + # Retreive tensors from the forward path. + softmax, target_mask, masked_target_1d = ctx.saved_tensors + + grad_input = calculate_gradients(softmax, grad_output, target_mask, masked_target_1d) + + return grad_input, None + + +def fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target): + """ + Performs cross entropy loss when logits are split across tensor parallel ranks + + Args: + vocab_parallel_logits: logits split across tensor parallel ranks + dimension is [sequence_length, batch_size, hidden_size] + + target: correct vocab ids of dimseion [sequence_length, micro_batch_size] + + """ + return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target) diff --git a/megatron/core/fusions/fused_layer_norm.py b/megatron/core/fusions/fused_layer_norm.py index 8b308b9727..d02ae7aa4d 100644 --- a/megatron/core/fusions/fused_layer_norm.py +++ b/megatron/core/fusions/fused_layer_norm.py @@ -1,46 +1,71 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. import importlib +import inspect import numbers import torch +from torch import Tensor from torch.nn import init from torch.nn.parameter import Parameter +from megatron.core.transformer import TransformerConfig from megatron.core.utils import make_viewless_tensor try: from apex.contrib.layer_norm.layer_norm import FastLayerNormFN HAVE_PERSIST_LAYER_NORM = True -except: +except ImportError: HAVE_PERSIST_LAYER_NORM = False try: from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction HAVE_FUSED_LAYER_NORM = True -except: +except ImportError: HAVE_FUSED_LAYER_NORM = False class FusedLayerNorm(torch.nn.Module): + """Layer Norm, fused into a single CUDA kernel. + + Args: + hidden_size (int): Transformer hidden dimension. + + eps (float): Epsilon added to denominator, for numerical stability. + + persist_layer_norm (bool): Use persistent fused layer norm kernel. + This kernel supports only a set of hidden sizes. Please + check persist_ln_hidden_sizes if your hidden size is supported. + + zero_centered_gamma (bool): Adjust LayerNorm weights such that they are + centered around zero. This improves numerical stability. + + config (TransformerConfig): Transformer config. Include to match custom + layer norm interfaces. + + normalization (str): Normalization type, used for Transformer Engine. + Must equal 'LayerNorm' here. 
+ """ + def __init__( self, - hidden_size, - eps=1e-5, - persist_layer_norm=True, - sequence_parallel=False, - zero_centered_gamma=False, - normalization="LayerNorm", + config: TransformerConfig, + hidden_size: int, + eps: float = 1e-5, + persist_layer_norm: bool = True, + zero_centered_gamma: bool = False, + normalization: str = "LayerNorm", # included to match TE interface ): super().__init__() - self.zero_centered_gamma = zero_centered_gamma - self.normalization = normalization - assert normalization == "LayerNorm", '({}) is not supported in ' 'FusedLayerNorm'.format( - normalization - ) + self.config = config + + self.zero_centered_gamma = self.config.layernorm_zero_centered_gamma + assert ( + self.config.normalization == "LayerNorm" + ), f'({self.config.normalization}) is not supported in FusedLayerNorm' # List of hiddens sizes supported in the persistent layer norm kernel # If the hidden size is not supported, fall back to the non-persistent @@ -71,22 +96,24 @@ def __init__( 49152, 65536, ] + persist_layer_norm = self.config.persist_layer_norm if hidden_size not in persist_ln_hidden_sizes or not HAVE_PERSIST_LAYER_NORM: persist_layer_norm = False if not persist_layer_norm and not HAVE_FUSED_LAYER_NORM: # TODO: Add pytorch only layer norm - raise ValueError(f'Apex must currently be installed to use megatron core.') + raise ValueError(f'Apex must be installed to use FusedLayerNorm.') if isinstance(hidden_size, numbers.Integral): hidden_size = (hidden_size,) self.hidden_size = torch.Size(hidden_size) self.eps = eps - self.weight = Parameter(torch.Tensor(*hidden_size)) - self.bias = Parameter(torch.Tensor(*hidden_size)) + # Parameters need to be initialized with torch.empty rather than torch.Tensor for correct device placement with nemo2. + self.weight = Parameter(torch.empty(*hidden_size)) + self.bias = Parameter(torch.empty(*hidden_size)) self.reset_parameters() self.persist_layer_norm = persist_layer_norm - self.sequence_parallel = sequence_parallel + self.sequence_parallel = self.config.sequence_parallel # set sequence parallelism flag on weight and bias parameters setattr(self.weight, 'sequence_parallel', self.sequence_parallel) @@ -101,12 +128,17 @@ def reset_parameters(self): init.ones_(self.weight) init.zeros_(self.bias) - def forward(self, input): + def forward(self, input: Tensor) -> Tensor: weight = self.weight + 1 if self.zero_centered_gamma else self.weight if self.persist_layer_norm: - output = FastLayerNormFN.apply(input, weight, self.bias, self.eps) + if 'memory_efficient' in inspect.getfullargspec(FastLayerNormFN.forward).args: + output = FastLayerNormFN.apply( + input, weight, self.bias, self.eps, self.config.memory_efficient_layer_norm + ) + else: + output = FastLayerNormFN.apply(input, weight, self.bias, self.eps) # Apex's fast layer norm function outputs a 'view' tensor (i.e., has # a populated '_base' field). 
This will result in schedule.py's @@ -117,8 +149,21 @@ def forward(self, input): ) else: - output = FusedLayerNormAffineFunction.apply( - input, weight, self.bias, self.hidden_size, self.eps - ) + if ( + 'memory_efficient' + in inspect.getfullargspec(FusedLayerNormAffineFunction.forward).args + ): + return FusedLayerNormAffineFunction.apply( + input, + weight, + self.bias, + self.hidden_size, + self.eps, + self.config.memory_efficient_layer_norm, + ) + else: + return FusedLayerNormAffineFunction.apply( + input, weight, self.bias, self.hidden_size, self.eps + ) return output diff --git a/megatron/core/fusions/fused_softmax.py b/megatron/core/fusions/fused_softmax.py index 56eb2e8011..c7bfbb768b 100644 --- a/megatron/core/fusions/fused_softmax.py +++ b/megatron/core/fusions/fused_softmax.py @@ -1,10 +1,12 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +from typing import Optional import torch import torch.nn as nn from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.utils import get_default_causal_mask class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): @@ -96,7 +98,7 @@ class FusedScaleMaskSoftmax(nn.Module): """ fused operation: scaling + mask + softmax - Arguments: + Args: input_in_fp16: flag to indicate if input in fp16 data format. input_in_bf16: flag to indicate if input in bf16 data format. attn_mask_type: attention mask type (pad or causal) @@ -131,7 +133,12 @@ def __init__( assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled" - def forward(self, input, mask): + def forward(self, input: torch.Tensor, mask: Optional[torch.Tensor]): + """Forward pass of softmax with masked input. + + In case attn_mask_type is causal the mask is generated and None can be passed. + A user-defined mask is only needed when attn_mask_type is not causal. + """ # [b, np, sq, sk] assert input.dim() == 4 @@ -186,6 +193,15 @@ def forward_torch_softmax(self, input, mask): if self.scale is not None: input = input * self.scale + + # Generate causal mask if not given + sq, sk = input.size(2), input.size(3) + if self.attn_mask_type == AttnMaskType.causal and mask is None and sq > 1: + # If sq == 1 then either KV cache is used or one-element context is passed + # so keeping mask=None in this case; subsequent code should handle it + assert sq == sk, "causal mask is only for self attention" + mask = get_default_causal_mask(sq) + mask_output = self.mask_func(input, mask) if mask is not None else input probs = torch.nn.Softmax(dim=-1)(mask_output) diff --git a/megatron/core/inference/__init__.py b/megatron/core/inference/__init__.py new file mode 100644 index 0000000000..f8011007a5 --- /dev/null +++ b/megatron/core/inference/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/core/inference/async_stream.py b/megatron/core/inference/async_stream.py new file mode 100644 index 0000000000..b49d004441 --- /dev/null +++ b/megatron/core/inference/async_stream.py @@ -0,0 +1,67 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright 2025 The vLLM authors. +# +# This code was adopted from https://github.com/vllm-project/vllm/ +# This source code is licensed under the Apache license found in the +# LICENSE file in the root directory of this source tree. 
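To make the causal-mask fallback in `forward_torch_softmax` above concrete, here is a rough sketch of the mask it implies; the actual `get_default_causal_mask` helper in megatron/core/transformer/utils.py may additionally cache the result and place it on the current device, so treat this only as an illustration of shape and convention:

import torch

def default_causal_mask_sketch(sq: int) -> torch.Tensor:
    # True marks positions to mask out (future tokens); mask_func later fills
    # those positions with a large negative value before the softmax.
    return torch.triu(torch.ones(sq, sq, dtype=torch.bool), diagonal=1).view(1, 1, sq, sq)

# default_causal_mask_sketch(4)[0, 0]:
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])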
+ +import asyncio +from typing import Any, AsyncGenerator, Callable, Optional, Type, Union + +from megatron.core.inference.inference_request import InferenceRequest + +STOP_ITERATION = Exception() + + +class AsyncStream: + """ + Class for encapsulating an asynchronous stream of InferenceRequest outputs. + + Adopted from https://github.com/vllm-project/vllm/blob/eb881ed006ca458b052905e33f0d16dbb428063a/vllm/v1/engine/async_stream.py # pylint: disable=line-too-long + """ + + def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: + self._request_id = request_id + self._cancel = cancel + self._queue: asyncio.Queue = asyncio.Queue() + self._finished = False + self._loop = asyncio.get_running_loop() + + def put(self, item: Union[InferenceRequest, Exception]) -> None: + """Adds a new value to the stream""" + if not self._finished: + self._loop.call_soon_threadsafe(self._queue.put_nowait, item) + + def finish(self, exception: Optional[Union[BaseException, Type[BaseException]]] = None) -> None: + """Completes the stream by adding a sentinel value""" + if not self._finished: + self._finished = True + self._loop.call_soon_threadsafe( + self._queue.put_nowait, + exception if self._is_raisable(exception) else STOP_ITERATION, + ) + + @property + def finished(self) -> bool: + """Whether the stream has finished""" + return self._finished + + async def generator(self) -> AsyncGenerator[InferenceRequest, None]: + """Creates an AsyncGenerator over the stream queue""" + try: + while True: + result = await self._queue.get() + if self._is_raisable(result): + if result == STOP_ITERATION: + return + raise result + yield result + except GeneratorExit: + self._cancel() + raise asyncio.CancelledError from None + + @staticmethod + def _is_raisable(value: Any): + return isinstance(value, BaseException) or ( + isinstance(value, type) and issubclass(value, BaseException) + ) diff --git a/megatron/core/inference/common_inference_params.py b/megatron/core/inference/common_inference_params.py new file mode 100644 index 0000000000..7955bb6fc1 --- /dev/null +++ b/megatron/core/inference/common_inference_params.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from megatron.core.inference.sampling_params import ( # noqa: F401 # pylint: disable=unused-import + SamplingParams as CommonInferenceParams, +) diff --git a/megatron/core/inference/communication_utils.py b/megatron/core/inference/communication_utils.py new file mode 100644 index 0000000000..8b2f5188f0 --- /dev/null +++ b/megatron/core/inference/communication_utils.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import torch + +from megatron.core import parallel_state + + +def _is_cuda(tensor): + """Check if a tensor is not none and is cuda.""" + assert tensor is not None + assert tensor.is_cuda + + +def broadcast_from_last_pipeline_stage(size, dtype, tensor=None): + """Broadcast a tensor from last pipeline stage to all ranks.""" + + if parallel_state.is_pipeline_last_stage(): + assert size == list( + tensor.shape + ), f"Expected tensor of shape {size} but got {list(tensor.shape)}" + assert dtype == tensor.dtype, f"Expected tensor of type {dtype} but got {tensor.dtype}" + _is_cuda(tensor) + assert tensor.is_contiguous() + else: + tensor = torch.empty(size, dtype=dtype, device=torch.cuda.current_device()) + # Get the group and corresponding source rank. 
+ src = parallel_state.get_pipeline_model_parallel_last_rank() + group = parallel_state.get_pipeline_model_parallel_group() + torch.distributed.broadcast(tensor, src, group) + return tensor + + +def recv_from_prev_pipeline_rank_(recv_buffer=None): + """Receive from previous pipeline stage and update the + input buffer inplace.""" + recv_prev_op = torch.distributed.P2POp( + torch.distributed.irecv, recv_buffer, parallel_state.get_pipeline_model_parallel_prev_rank() + ) + reqs = torch.distributed.batch_isend_irecv([recv_prev_op]) + for req in reqs: + req.wait() + # To protect against race condition when using batch_isend_irecv(). + torch.cuda.synchronize() + + +def send_to_next_pipeline_rank(tensor=None): + """Send output to the next pipeline stage.""" + send_next_op = torch.distributed.P2POp( + torch.distributed.isend, tensor, parallel_state.get_pipeline_model_parallel_next_rank() + ) + reqs = torch.distributed.batch_isend_irecv([send_next_op]) + for req in reqs: + req.wait() + # To protect against race condition when using batch_isend_irecv(). + torch.cuda.synchronize() diff --git a/megatron/core/inference/engines/__init__.py b/megatron/core/inference/engines/__init__.py new file mode 100644 index 0000000000..f8011007a5 --- /dev/null +++ b/megatron/core/inference/engines/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/core/inference/engines/abstract_engine.py b/megatron/core/inference/engines/abstract_engine.py new file mode 100644 index 0000000000..6893f6a905 --- /dev/null +++ b/megatron/core/inference/engines/abstract_engine.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from abc import ABC, abstractmethod +from typing import List + + +class AbstractEngine(ABC): + @staticmethod + @abstractmethod + def generate(self) -> dict: + """The abstract backend's generate function. + + To define a new backend, implement this and return the outputs as a dictionary. + + Returns: + dict: The output dictionary containing keys for `input_prompt`, `generated_text`, `generated_tokens`. + """ + pass diff --git a/megatron/core/inference/engines/mcore_engine.py b/megatron/core/inference/engines/mcore_engine.py new file mode 100644 index 0000000000..5f52c54124 --- /dev/null +++ b/megatron/core/inference/engines/mcore_engine.py @@ -0,0 +1,228 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import asyncio +import warnings +from collections import OrderedDict +from typing import AsyncGenerator, Dict, List, Optional, Union + +import torch + +from megatron.core.inference.async_stream import AsyncStream +from megatron.core.inference.engines.abstract_engine import AbstractEngine +from megatron.core.inference.inference_request import InferenceRequest +from megatron.core.inference.sampling_params import SamplingParams +from megatron.core.inference.scheduler import Scheduler +from megatron.core.inference.text_generation_controllers.text_generation_controller import ( + TextGenerationController, +) + + +class MCoreEngine(AbstractEngine): + """The Megatron core backend constructor + + This is the backend that does a simple forward pass on the model. + Supports any model that is callable (Accepts the inputs and outputs the tensor) + + Args: + text_generation_controller (TextGenerationController): A text generation + controller that will be used to define how to preprocess prompts, generate + outputs and detokenizer the output tokens. 
+ max_batch_size (int, optional): The maximum number of requests to process at once. + Will be set from the InferenceWrapperConfig in `text_generation_controller` by + default. + random_seed (int, optional): Use a random seed if you want deterministic + results. Defaults to None. + """ + + def __init__( + self, + text_generation_controller: TextGenerationController, + max_batch_size: Optional[int] = None, + random_seed: Optional[int] = None, + ): + inference_wrapper_config = ( + text_generation_controller.inference_wrapped_model.inference_wrapper_config + ) + inference_max_batch_size = inference_wrapper_config.inference_max_requests + if max_batch_size is None: + max_batch_size = inference_max_batch_size + elif max_batch_size > inference_max_batch_size: + warnings.warn( + f"Engine `max_batch_size` ({max_batch_size}) > " + f"`inference_max_requests` in `inference_wrapper_config` " + f"({inference_max_batch_size}); setting `max_batch_size` to " + f"{inference_max_batch_size}", + UserWarning, + ) + max_batch_size = inference_max_batch_size + self.text_generation_controller = text_generation_controller + self.random_seed = random_seed + self.scheduler = Scheduler(max_batch_size=max_batch_size) + + def get_new_request_id(self) -> str: + """Gets a new request id from the scheduler""" + return self.scheduler.get_new_request_id() + + def add_request( + self, + prompt: Optional[str] = None, + add_BOS: bool = False, + encoder_prompt: Optional[str] = None, + inference_parameters: Optional[SamplingParams] = None, + streaming: bool = False, + inference_request: Optional[InferenceRequest] = None, + ) -> str: + """ + Adds a request to the scheduler and returns the request ID. + + Args: + prompt (str): A prompt string + add_BOS (bool): Whether to add BOS token to beginning of the prompt + encoder_prompt (str): The encoder prompt string + inference_parameters (SamplingParams): The inference parameters + streaming (bool): Whether to stream incremental outputs for this request + inference_request (InferenceRequest, optional): A fully constructed request. + Defaults to None. + + Returns: + The newly created request ID. + """ + assert ( + prompt is not None or inference_request is not None + ), f"At least one of `prompt` or `inference_request` must be specified" + + if inference_request is None: + prompt_tokens = self.text_generation_controller.tokenize_prompt(prompt, add_BOS) + else: + prompt_tokens = inference_request.prompt_tokens + + return self.scheduler.add_request( + prompt=prompt, + prompt_tokens=prompt_tokens, + encoder_prompt=encoder_prompt, + inference_parameters=inference_parameters, + streaming=streaming, + inference_request=inference_request, + ) + + def get_stream_generator( + self, request_id: str + ) -> Union[AsyncGenerator[InferenceRequest, None], None]: + """Returns the stream generator for the given request ID if it exists.""" + stream = self.scheduler.streams.get(request_id, None) + if stream is not None: + return stream.generator() + return None + + def generate( + self, + prompts: Optional[List[str]] = None, + add_BOS: bool = False, + encoder_prompts: Optional[List[str]] = None, + common_inference_params: Optional[SamplingParams] = None, + sampling_params: Optional[SamplingParams] = None, + inference_requests: Optional[List[InferenceRequest]] = None, + ) -> List[InferenceRequest]: + """The megatron core inference backend generate function + + This backend returns the output generations as a dictionary. 
+ It returns the prompt tokens along with the generated tokens, the prompt + plus the generated string and the output log probabilities if requested + + Args: + prompts (List[str]): All the prompts as a list of strings + add_BOS (bool): Whether to add BOS token to beginning of prompts + encoder_prompts (List[dict]): All the encoder prompts as a list of strings + common_inference_params: Deprecated. Only used for backward compatibility with + MCore <= 0.9.0. Use `sampling_params` going forward. + sampling_params (SamplingParams): The request-level sampling parameters + inference_requests (List[InferenceRequest]): A pre-populated list of inference requests + + Returns: + List[InferenceRequest]: The output is list of inference requests containing the + generated tokens, texts and log probs if required + """ + # TODO :M core- get rng state tracker + + request_ids: List[str] = [] + + if self.random_seed: + torch.random.manual_seed(self.random_seed) + + if inference_requests is None: + assert prompts is not None + + if common_inference_params: + sampling_params = common_inference_params + + for i in range(len(prompts)): + prompt = prompts[i] + encoder_prompt = encoder_prompts[i] if encoder_prompts is not None else None + request_id = self.add_request( + prompt=prompt, + encoder_prompt=encoder_prompt, + inference_parameters=sampling_params, + ) + request_ids.append(request_id) + else: + for inference_request in inference_requests: + request_ids.append(inference_request.request_id) + self.scheduler.add_request(inference_request=inference_request) + + self.run_engine() + + result: List[InferenceRequest] = [ + self.scheduler.completed_request_pool[request_id] for request_id in request_ids + ] + return result + + def run_engine(self): + """Main functionality to run inference + + Runs the engine until there are no requests in the queue. + + Args: + dynamic_generation (bool, optional): Set this to True, if you want + to enable dynamic batching. Mainly used with an inference server. + Defaults to False. + """ + while self.scheduler.have_requests_pending(): + active_requests: Dict[str, InferenceRequest] = self.scheduler.active_request_pool.copy() + active_streams: Dict[str, AsyncStream] = OrderedDict() + for request_id in active_requests: + if (stream := self.scheduler.streams.get(request_id, None)) is not None: + assert isinstance(stream, AsyncStream), stream + active_streams[request_id] = stream + result_dict: Dict[str, InferenceRequest] = ( + self.text_generation_controller.generate_all_output_tokens_static_batch( + active_requests, active_streams + ) + ) + + self.scheduler.update_requests_pools(result_dict=result_dict) + + # TODO: Later for dynamic batching we will do something like this + """ + if dynamic_batching: + result_dict: Dict[ + str, InferenceRequest + ] = self.text_generation_controller.generate_output_tokens_one_step_dynamic_batch( + active_requests + ) + self.scheduler.update_requests_pools(result_dict=result_dict) + """ + + def _wrapped_run_engine(self, cuda_device): + """ + Explicitly sets the CUDA device before running the engine. + + This is to ensure that the CUDA device is correctly propagated when running + in a new thread context. 
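Putting the `MCoreEngine` pieces above together, a typical synchronous call looks roughly like the sketch below; `controller` stands for an already constructed `TextGenerationController`, and the `num_tokens_to_generate` field on `SamplingParams` is assumed from its usual definition:

from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.sampling_params import SamplingParams

# controller: a TextGenerationController built elsewhere (construction not shown).
engine = MCoreEngine(text_generation_controller=controller, random_seed=1234)
requests = engine.generate(
    prompts=["The capital of France is"],
    sampling_params=SamplingParams(num_tokens_to_generate=32),
)
# Each element is an InferenceRequest with generated_tokens / generated_text filled in.
print(requests[0].generated_text)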
+ """ + torch.cuda.set_device(cuda_device) + self.run_engine() + + async def run_engine_async(self): + """Runs the engine asynchronously using asyncio""" + loop = asyncio.get_running_loop() + + await loop.run_in_executor(None, self._wrapped_run_engine, torch.cuda.current_device()) diff --git a/megatron/core/inference/inference_request.py b/megatron/core/inference/inference_request.py new file mode 100644 index 0000000000..d7ed1d801b --- /dev/null +++ b/megatron/core/inference/inference_request.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional + +import torch + +from megatron.core.inference.sampling_params import SamplingParams + + +# class syntax +class Status(Enum): + """Enum for status""" + + WAITING_IN_QUEUE = 1 + ACTIVE_AND_GENERATING_TOKENS = 2 + ACTIVE_BUT_NOT_GENERATING_TOKENS = 3 + COMPLETED = 4 + + +@dataclass(kw_only=True) +class InferenceRequest: + """Class for one inference request + + Containing relevant data for an inference request + + """ + + request_id: str + prompt: str + inference_parameters: Optional[SamplingParams] = None + prompt_tokens: Optional[List[int]] = None + arrival_time: Optional[float] = None + status: Optional[Status] = None + encoder_prompt: Optional[str] = None + generated_text: Optional[str] = None + segments: Optional[List[str]] = None + generated_segments: Optional[List[str]] = None + generated_sequence_lengths: Optional[List[int]] = None + generated_tokens: Optional[torch.Tensor] = None + generated_log_probs: Optional[torch.Tensor] = None + generated_length: Optional[int] = None + + +@dataclass(kw_only=True) +class VLMInferenceRequest(InferenceRequest): + """Class for a VLM inference request""" + + num_img_embeddings_per_tile: int + imgs: torch.Tensor + num_tiles: torch.Tensor + decoder_seq_length: int diff --git a/megatron/core/inference/model_inference_wrappers/__init__.py b/megatron/core/inference/model_inference_wrappers/__init__.py new file mode 100644 index 0000000000..f8011007a5 --- /dev/null +++ b/megatron/core/inference/model_inference_wrappers/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py new file mode 100644 index 0000000000..071ae0388d --- /dev/null +++ b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py @@ -0,0 +1,315 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import abc +import math +from typing import Any, Dict, Iterable, Optional, Union + +import torch + +from megatron.core import parallel_state, tensor_parallel +from megatron.core.inference.communication_utils import ( + recv_from_prev_pipeline_rank_, + send_to_next_pipeline_rank, +) +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, +) +from megatron.core.inference_params import InferenceParams +from megatron.core.models.gpt.gpt_model import GPTModel + + +# pylint: disable=line-too-long +class AbstractModelInferenceWrapper(abc.ABC): + """Abstract inference wrapper + + Extend this to create a version for your model. 
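For the streaming path, a request can be registered with `streaming=True`, consumed through the per-request `AsyncStream` generator, and driven by `run_engine_async`. A rough sketch under those assumptions (it assumes the scheduler registers a stream for streaming requests, and which fields are populated on intermediate results depends on the text generation controller):

import asyncio

from megatron.core.inference.sampling_params import SamplingParams

async def stream_one_prompt(engine, prompt):
    request_id = engine.add_request(
        prompt=prompt,
        inference_parameters=SamplingParams(num_tokens_to_generate=16),
        streaming=True,
    )
    stream_generator = engine.get_stream_generator(request_id)
    # Drive the engine in the background while consuming the stream.
    engine_task = asyncio.create_task(engine.run_engine_async())
    async for partial_request in stream_generator:
        print(partial_request.generated_text)
    await engine_task

# asyncio.run(stream_one_prompt(engine, "Hello, my name is"))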
+ """ + + def __init__( + self, + model: Union['LegacyGPTModel', GPTModel], # type: ignore[name-defined] + inference_wrapper_config: InferenceWrapperConfig, + ): + """Constructor for the model inference wrapper + + The wrapper prepares the model for inference, provides the required input data and runs the forward pass. + + Args: + model (Union[GPTModel, LegacyGPTModel]): The actual GPT model (MCore or MLM) + inference_wrapper_config (InferenceWrapperConfig): Has info like hidden size, vocab size etc. + """ + assert not isinstance( + model, Iterable + ), 'interleaving schedule is not supported for inference' + self.model = model + self.inference_wrapper_config = inference_wrapper_config + self.pipeline_communication_dtype = ( + torch.float + if self.inference_wrapper_config.fp32_residual_connection + else self.inference_wrapper_config.params_dtype + ) + + max_batch_size = self.inference_wrapper_config.inference_max_requests + max_sequence_length = self.inference_wrapper_config.inference_max_seq_length + self.inference_params = InferenceParams(max_batch_size, max_sequence_length) + + def prep_model_for_inference(self, prompts_tokens: torch.Tensor): + """A utility function for preparing model for inference + + The function gets called once before the auto regressive inference loop. + It puts the model in eval mode. + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len] + + """ + self.model.eval() + + # For TP only model both is_pp_first_stage and _is_pp_last_stage returns True + self.model_is_pipeline_parallel = not ( + parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage() + ) + + self.inference_params.reset() + + @abc.abstractmethod + def prep_inference_input(self, prompt_tokens) -> Dict[str, Any]: + """Prepares the inference input data. + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len] + + Returns: + A dict with all the inference input needed for the batch. + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_batch_for_context_window(self, *args, **kwargs) -> Dict[str, Any]: + """Returns the input data for inference + + This function gets called iteratively in the inference loop . It can be used to extract relevant input from the prompt tokens, attention mask etc. required for each step in inference. + + """ + raise NotImplementedError() + + def _forward(self, inference_input): + """Runs a forward pass of the model. + + Args: + inference_input(Dict[str, Any]): The input data. + inference_params(InferenceParams): The inference parameters. + + Returns: + The model output logits. + """ + tokens = inference_input["tokens"] + position_ids = inference_input["position_ids"] + attention_mask = inference_input["attention_mask"] + return self.model( + tokens, position_ids, attention_mask, inference_params=self.inference_params + ) + + def _get_batch_size_and_seq_len( + self, tokens: torch.Tensor, recv_buffer_seq_len: Optional[int] = None + ): + """ + Returns the batch size and sequence length based on the tokens tensor and recv_buffer_seq_len. + + Args: + tokens (torch.Tensor): The input tensor of shape (batch_size, seq_len). + recv_buffer_seq_len (int, optional): An optional recv buffer sequence length. + + Returns: + tuple: A tuple (batch_size, seq_len), where batch_size is the first dimension of tokens + and seq_len is either the second dimension or recv_buffer_seq_len. 
+ """ + batch_size = tokens.shape[0] + seq_len = recv_buffer_seq_len if recv_buffer_seq_len is not None else tokens.shape[1] + return batch_size, seq_len + + def _allocate_recv_buffer(self, batch_size, seq_len): + """Receive happens between the layers with size [seq_len, batch_size, hidden_size].""" + recv_size = (seq_len, batch_size, self.inference_wrapper_config.hidden_size) + return torch.empty( + recv_size, dtype=self.pipeline_communication_dtype, device=torch.cuda.current_device() + ) + + def forward_pass_without_pipeline_parallel( + self, inference_input: Dict[str, Any] + ) -> torch.Tensor: + """Utility to carry out simple forward pass for TP or no model parallel models + + Runs a very simple forward pass for model. Used in the case of models without any parallelism or only tensor parallelism. + + Args: + inference_input (Dict[str, Any]): A dict containg the inputs for the gpt model [tokens, position ids, attention mask] + + Returns: + torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size] + """ + tokens = inference_input["tokens"] + logits = self._forward(inference_input) + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + self.inference_params.sequence_len_offset += tokens.size(1) + + return logits + + def forward_pass_with_pipeline_parallel_small_input_batch( + self, inference_input: Dict[str, Any], recv_buffer_seq_len: Optional[int] = None + ) -> torch.Tensor: + """Utility to carry out forward pass for PP models with very small inputs + + If a model is pipeline parallel, yet, the input global batch is very small, we compute a foward pass on the entire global batch, rather than splitting it up into micro batches and doing something more complex as in the forward_pass_with_pipeline_parallel_large_input_batch method + + Args: + inference_input (Dict[str, Any]): A dict containing the inputs for the gpt model [tokens, position ids, attention mask] + recv_buffer_seq_len (int): An optional sequence length for the pipeline parallel recv buffer. + + Returns: + torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size] + """ + tokens = inference_input["tokens"] + position_ids = inference_input["position_ids"] + attention_mask = inference_input["attention_mask"] + batch_size, seq_len = self._get_batch_size_and_seq_len(tokens, recv_buffer_seq_len) + recv_buffer = None + if not parallel_state.is_pipeline_first_stage(): + recv_buffer = self._allocate_recv_buffer(batch_size, seq_len) + recv_from_prev_pipeline_rank_(recv_buffer) + + self.model.set_input_tensor(recv_buffer) + output_tensor = self._forward(inference_input) + + if not parallel_state.is_pipeline_last_stage(): + send_to_next_pipeline_rank(output_tensor.type(dtype=self.pipeline_communication_dtype)) + + self.inference_params.sequence_len_offset += seq_len + + logits = None + if parallel_state.is_pipeline_last_stage(): + logits = output_tensor + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + + # Explicitly cast logits to expected dtype + logits = logits.to(self.inference_wrapper_config.params_dtype) + + return logits + + def forward_pass_with_pipeline_parallel_large_input_batch( + self, inference_input: Dict[str, Any], recv_buffer_seq_len=None + ) -> torch.Tensor: + """Utility to carry out forward pass PP models. + + Runs the forward pass for models which are pipeline parallel. 
+ This is more complex than forward_pass_with_pipeline_parallel_small_input_batch because + this splits the global batch into small micro batches and runs them through the model. + + Args: + inference_input (Dict[str, Any]): A dict containg the inputs for the gpt model [tokens, position ids, attention mask] + recv_buffer_seq_len (int): An optional sequence length for the pipeline parallel recv buffer. + + Returns: + torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size] + """ + tokens = inference_input["tokens"] + position_ids = inference_input["position_ids"] + attention_mask = inference_input["attention_mask"] + micro_batch_size = max( + 1, + self.inference_wrapper_config.inference_batch_times_seqlen_threshold // tokens.size(1), + ) + batch_size, seq_len = self._get_batch_size_and_seq_len(tokens, recv_buffer_seq_len) + # Round up to account for the last partial micro batch if present + num_micro_batches = math.ceil(batch_size / micro_batch_size) + + logits = None + # Preallocate memory for output logits. + if parallel_state.is_pipeline_last_stage(): + logits = torch.empty( + (batch_size, seq_len, self.inference_wrapper_config.padded_vocab_size), + dtype=self.pipeline_communication_dtype, + device=torch.cuda.current_device(), + ) + + recv_buffer = None + if not parallel_state.is_pipeline_first_stage(): + recv_buffer = self._allocate_recv_buffer(micro_batch_size, seq_len) + for micro_batch_index in range(num_micro_batches): + start = micro_batch_index * micro_batch_size + end = min(start + micro_batch_size, batch_size) + tokens2use = tokens[start:end, ...] + position_ids2use = position_ids[start:end, ...] + current_micro_batch_size = end - start + + # Need to change recv buffer shape for the last partial microbatch (if exists) + if current_micro_batch_size != micro_batch_size: + recv_buffer = self._allocate_recv_buffer(current_micro_batch_size, seq_len) + + if not parallel_state.is_pipeline_first_stage(): + recv_from_prev_pipeline_rank_(recv_buffer) + + self.model.set_input_tensor(recv_buffer) + output_tensor = self._forward( + { + "tokens": tokens2use, + "position_ids": position_ids2use, + "attention_mask": attention_mask, + } + ) + + if not parallel_state.is_pipeline_last_stage(): + send_to_next_pipeline_rank(output_tensor) + + self.inference_params.batch_size_offset += current_micro_batch_size + + if parallel_state.is_pipeline_last_stage(): + output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region( + output_tensor + ) + assert logits is not None + logits[start:end, ...] = output_tensor + + # Explicitly cast logits to expected dtype + logits = logits.to(self.inference_wrapper_config.params_dtype) + + # Once done with all micro batches, we reset batch size offset and seq len offset + self.inference_params.sequence_len_offset += seq_len + self.inference_params.batch_size_offset = 0 + + # NOTE: Only returns the logits on the last pipeline stage + return logits + + def run_one_forward_step( + self, inference_input: Dict[str, Any], recv_buffer_seq_len: Optional[int] = None + ) -> torch.Tensor: + """The forward pass of the model for inference + + Appropriate utility is called for the forward pass depending on the type of model parallelism used + + Args: + inference_input (Dict[str, Any]): A dict containg the inputs for the gpt model [tokens, position ids, attention mask] + recv_buffer_seq_len (int): An optional sequence length for the pipeline parallel recv buffer. 
+ + Returns: + torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]. The logits are returned only in the last pipeline stage for PP models. + """ + if self.model_is_pipeline_parallel: + tokens = inference_input["tokens"] + current_batch_size, seq_len = self._get_batch_size_and_seq_len( + tokens, recv_buffer_seq_len + ) + # If input batch is large, we need to split into micro batches and run the forward pass + if ( + current_batch_size * seq_len + > self.inference_wrapper_config.inference_batch_times_seqlen_threshold + ): + return self.forward_pass_with_pipeline_parallel_large_input_batch( + inference_input, recv_buffer_seq_len + ) + else: + # If input batch is very small we can do a simple forward pass on the entire global batch + return self.forward_pass_with_pipeline_parallel_small_input_batch( + inference_input, recv_buffer_seq_len + ) + else: + return self.forward_pass_without_pipeline_parallel(inference_input) diff --git a/megatron/core/inference/model_inference_wrappers/gpt/__init__.py b/megatron/core/inference/model_inference_wrappers/gpt/__init__.py new file mode 100644 index 0000000000..f8011007a5 --- /dev/null +++ b/megatron/core/inference/model_inference_wrappers/gpt/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py new file mode 100644 index 0000000000..e5a19bbfde --- /dev/null +++ b/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py @@ -0,0 +1,102 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from typing import Any, Dict, Tuple + +import torch + +from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( + AbstractModelInferenceWrapper, +) +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, +) +from megatron.core.models.gpt import GPTModel + + +# pylint: disable=line-too-long +class GPTInferenceWrapper(AbstractModelInferenceWrapper): + """Inference wrapper for GPT model""" + + def __init__(self, model: GPTModel, inference_wrapper_config: InferenceWrapperConfig): + """Constructor for the model inference wrapper + + The wrapper prepares the model for inference, provides the required input data, and runs the forward pass + + Args: + model (GPTModel): The GPT model (MCore or legacy) + inference_wrapper_config (InferenceWrapperConfig): Has info like hidden size, vocab size etc + """ + super().__init__(model, inference_wrapper_config) + + def prep_inference_input(self, prompts_tokens: torch.Tensor) -> Dict[str, Any]: + """Prepares the inference input data. + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len] + + Returns: + A dict with all the inference input needed for the batch. 
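The dispatch in `run_one_forward_step` and the micro-batching in `forward_pass_with_pipeline_parallel_large_input_batch` above reduce to a little arithmetic on `inference_batch_times_seqlen_threshold`; a worked sketch with hypothetical sizes:

import math

# Hypothetical sizes.
inference_batch_times_seqlen_threshold = 512
batch_size, seq_len = 12, 100

# batch_size * seq_len = 1200 > 512, so the large-input (micro-batched) path is taken.
micro_batch_size = max(1, inference_batch_times_seqlen_threshold // seq_len)  # 5
num_micro_batches = math.ceil(batch_size / micro_batch_size)                  # 3

for micro_batch_index in range(num_micro_batches):
    start = micro_batch_index * micro_batch_size
    end = min(start + micro_batch_size, batch_size)
    # Slices processed: [0:5], [5:10], [10:12] (the last micro batch is partial).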
+ """ + attention_mask, position_ids = self._build_attention_mask_and_position_ids(prompts_tokens) + return { + "tokens": prompts_tokens, + "attention_mask": attention_mask, + "position_ids": position_ids, + } + + def _build_attention_mask_and_position_ids( + self, prompts_tokens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Builds the full attention mask and position ids for the input tokens + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len] + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The attention mask of shape [1, 1, max_seq_len, max_seq_len] and position ids of shape [batch_size, max_seq_len] + """ + seq_length = prompts_tokens.size(1) + attention_mask = torch.tril( + torch.ones((1, seq_length, seq_length), device=prompts_tokens.device) + ).view(1, 1, seq_length, seq_length) + # Convert to boolean + attention_mask = attention_mask < 0.5 + + position_ids = ( + torch.arange(seq_length, dtype=torch.long, device=prompts_tokens.device) + .unsqueeze(0) + .expand_as(prompts_tokens) + ) + + return attention_mask, position_ids + + def get_batch_for_context_window( + self, + inference_input: Dict[str, Any], + context_start_position: int, + context_end_position: int, + ) -> Dict[str, Any]: + """Returns the inference data given context window + + This function gets called iteratively in a loop . Given the start and end context positions , it extracts the appropriate data. + + Args: + inference_input (Dict[str, Any]): The inference input for the batch. + context_start_position (int): Start of the context window. During the first inference step it is mostly 0 + context_end_position (int): End of the context window. During the last inference step it will mostly be the max generated sequence length. + + Returns: + Dict[str, Any]: A dict of inputs that will be used by your model in the forward step + """ + tokens = inference_input["tokens"] + position_ids = inference_input["position_ids"] + attention_mask = inference_input["attention_mask"] + tokens2use = tokens[:, context_start_position:context_end_position] + positions2use = position_ids[:, context_start_position:context_end_position] + attention_mask2use = attention_mask[ + ..., context_start_position:context_end_position, :context_end_position + ] + return { + "tokens": tokens2use, + "position_ids": positions2use, + "attention_mask": attention_mask2use, + } diff --git a/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py b/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py new file mode 100644 index 0000000000..5c88d1ba48 --- /dev/null +++ b/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
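The slicing done by `GPTInferenceWrapper.get_batch_for_context_window` above is easiest to see with concrete shapes; a small sketch with hypothetical sizes:

import torch

batch_size, max_seq_len = 2, 8
tokens = torch.arange(batch_size * max_seq_len).view(batch_size, max_seq_len)
# Boolean mask with True marking future positions, shape [1, 1, max_seq_len, max_seq_len].
attention_mask = torch.triu(
    torch.ones(1, 1, max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1
)

# Prefill step over the first 6 prompt tokens.
prefill_tokens = tokens[:, 0:6]              # [2, 6]
prefill_mask = attention_mask[..., 0:6, :6]  # [1, 1, 6, 6]

# A later decode step: one new token attends to all 7 previous positions.
decode_tokens = tokens[:, 6:7]               # [2, 1]
decode_mask = attention_mask[..., 6:7, :7]   # [1, 1, 1, 7]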
+from dataclasses import dataclass + +import torch + + +@dataclass +class InferenceWrapperConfig: + """Config for the model inference wrapper + + NOTE : All the arguments here are obtained from arguments.py file + """ + + hidden_size: int + """Receive happens between the layers during PP with size [seq_len, batch_size, hidden_size]""" + + params_dtype: torch.dtype + """Can be torch.float or torch.half if --fp16 is used, or torch.bfloat16 if --bf16 is used""" + + inference_batch_times_seqlen_threshold: int + """if (batch-size * sequence-length) is smaller than this threshold then we will not pipeline + the batch.""" + + padded_vocab_size: int + """The final padded vocab size (Padded to make it divisible by + --make-vocab-size-divisible-by value)""" + + inference_max_requests: int = 8 + """ Maximum number of requests for inference (prefill & decode). Necessary for CUDA graphs. """ + + inference_max_seq_length: int = 2560 + """ Maximum sequence length for inference (prefill & decode). Necessary for CUDA graphs. """ + + fp32_residual_connection: bool = False + """Move residual connections to fp32. Obtained from arguments.py""" + + def add_attributes(self, attribute_value_pair: dict): + """Utility to add more attributes to inference params + + Use this method to pass in a custom dictionary to add more configs to the instance created. + Use as follows: + c = InferenceWrapperConfig + c.add_attributes({'precision':'fp32'}) + + Args: + attribute_value_pair (dict): A dictionary containing attributes as the key names and + corresponding values. + """ + for key, value in attribute_value_pair.items(): + setattr(self, key, value) diff --git a/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py new file mode 100644 index 0000000000..55ff3f3572 --- /dev/null +++ b/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py @@ -0,0 +1,208 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from typing import Any, Dict + +import torch + +from megatron.core import parallel_state +from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( + GPTInferenceWrapper, +) +from megatron.core.inference_params import InferenceParams + + +# pylint: disable=line-too-long +class VLMInferenceWrapper(GPTInferenceWrapper): + """Inference wrapper for VLMs""" + + def prep_model_for_inference(self, prompts_tokens: torch.Tensor): + """A utility function for preparing model for inference + + The function gets called once before the auto regressive inference loop. + It puts the model in eval mode. + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len] + + """ + super().prep_model_for_inference(prompts_tokens) + + # For TP only model both is_pp_first_stage and _is_pp_last_stage returns True + self.model_is_pipeline_parallel = not ( + parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage() + ) + + self._recv_only_vision_embeds = False + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + # Checks if the previous stage only has a vision encoder, and that the current stage + # has part of the LM decoder. In this case, the current stage should only receive + # vision embeddings. 
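Tying the pieces together, the config above is what a `GPTInferenceWrapper` is constructed with; a sketch with hypothetical model dimensions, where `gpt_model` stands for an already built GPTModel:

import torch

from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
    GPTInferenceWrapper,
)
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
    InferenceWrapperConfig,
)

inference_wrapper_config = InferenceWrapperConfig(
    hidden_size=4096,
    params_dtype=torch.bfloat16,
    inference_batch_times_seqlen_threshold=512,
    padded_vocab_size=131072,
    inference_max_requests=8,
    inference_max_seq_length=2560,
)
# Extra attributes can be attached later, as described in add_attributes.
inference_wrapper_config.add_attributes({"precision": "fp32"})

inference_wrapped_model = GPTInferenceWrapper(gpt_model, inference_wrapper_config)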
+ if pp_rank > 0: + self._recv_only_vision_embeds = ( + parallel_state.is_inside_encoder(pp_rank - 1) + and (not parallel_state.is_inside_decoder(pp_rank - 1)) + and parallel_state.is_inside_decoder() + ) + + # Checks if the current stage only has a vision encoder + self._encoder_only = ( + parallel_state.is_inside_encoder() and not parallel_state.is_inside_decoder() + ) + + # For TP only model both is_pp_first_stage and _is_pp_last_stage returns True + self.model_is_pipeline_parallel = not ( + parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage() + ) + + def prep_inference_input( + self, + prompts_tokens: torch.Tensor, + num_img_embeddings_per_tile: int, + images: torch.Tensor, + num_tiles: torch.Tensor, + decoder_seq_length: int, + ): + """Prepares the inference input data. + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len] + num_img_embeddings_per_tile (int): The number of image embeddings per tile + images (torch.Tensor): The image embeddings + num_tiles (torch.Tensor): The number of tiles for each input image + decoder_seq_length (int): The decoder sequence length + """ + inference_input = super().prep_inference_input(prompts_tokens) + + total_num_tiles = torch.sum(num_tiles).item() + num_img_embeddings = num_img_embeddings_per_tile * total_num_tiles + + batch_size, max_sequence_length = prompts_tokens.shape + self.inference_params = InferenceParams( + batch_size, max_sequence_length + num_img_embeddings + ) + + inference_input["images"] = images + inference_input["num_tiles"] = num_tiles + inference_input["num_img_embeddings"] = num_img_embeddings + inference_input["decoder_seq_length"] = decoder_seq_length + + return inference_input + + def get_batch_for_context_window( + self, + inference_input: Dict[str, Any], + context_start_position: int, + context_end_position: int, + ) -> Dict[str, Any]: + """Returns the inference data given context window + + This function gets called iteratively in a loop . Given the start and end context positions , it extracts the appropriate data. + + Args: + inference_input (Dict[str, Any]): The inference input for the batch. + context_start_position (int): Start of the context window. During the first inference step it is mostly 0 + context_end_position (int): End of the context window. During the last inference step it will mostly be the max generated sequence length. + + Returns: + Dict[str, Any]: A dict of inputs that will be used by your model in the forward step + """ + tokens = inference_input["tokens"] + position_ids = inference_input["position_ids"] + images = inference_input["images"] + num_tiles = inference_input["num_tiles"] + num_img_embeddings = inference_input["num_img_embeddings"] + decoder_seq_length = inference_input["decoder_seq_length"] + + tokens2use = tokens[:, context_start_position:context_end_position] + positions2use = position_ids[:, context_start_position:context_end_position] + + return { + "tokens": tokens2use, + "position_ids": positions2use, + "images": images, + "num_tiles": num_tiles, + "num_img_embeddings": num_img_embeddings, + "decoder_seq_length": decoder_seq_length, + } + + def _forward(self, inference_input: Dict[str, Any]): + """Runs a forward pass of the model. + + Args: + inference_input(Dict[str, Any]): The input data. + + Returns: + The model output logits. 
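The sizing logic in `VLMInferenceWrapper.prep_inference_input` above boils down to expanding the KV cache by the total number of image embeddings; a numeric sketch with hypothetical tile counts (576 embeddings per tile is just an example figure):

import torch

from megatron.core.inference_params import InferenceParams

num_img_embeddings_per_tile = 576
num_tiles = torch.tensor([1, 4])                                     # two images, 1 and 4 tiles
total_num_tiles = int(torch.sum(num_tiles))                          # 5
num_img_embeddings = num_img_embeddings_per_tile * total_num_tiles   # 2880

batch_size, max_sequence_length = 1, 64
# The KV cache must hold the text tokens plus the expanded image embeddings.
inference_params = InferenceParams(batch_size, max_sequence_length + num_img_embeddings)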
+ """ + images = inference_input["images"] + tokens = inference_input["tokens"] + position_ids = inference_input["position_ids"] + num_image_tiles = inference_input["num_tiles"] + + output = self.model( + images, + tokens, + position_ids=position_ids, + attention_mask=None, + inference_params=self.inference_params, + num_image_tiles=num_image_tiles, + runtime_gather_output=True, + ) + if isinstance(output, tuple): + logits, _ = output + else: + logits = output + return logits + + def run_one_forward_step(self, inference_input: Dict[str, Any]) -> torch.Tensor: + tokens = inference_input["tokens"] + num_image_tokens = (tokens == self.model.module.image_token_index).sum().item() + num_img_embeddings = inference_input["num_img_embeddings"] + decoder_seq_length = inference_input["decoder_seq_length"] + num_tokens = tokens.size(1) + recv_buffer_seq_len = None + if num_image_tokens > 0: + # When there are image tokens and this stage only receives vision embeddings, + # adjust the recv buffer seq length to match the image embeddings sequence length. + # If there are image tokens and this stage receives full embeddings, make sure we + # compensate for expansion of image tokens. + # Note that this will set a recv_buffer_seq_len for the encoder stage, + # this length is irrelevant since that recv buffer is never allocated. + if self._recv_only_vision_embeds: + recv_buffer_seq_len = num_img_embeddings + else: + recv_buffer_seq_len = min( + num_img_embeddings + num_tokens - num_image_tokens, decoder_seq_length + ) + elif self._recv_only_vision_embeds: + # If this stage only receives vision embeddings and there are no image tokens + # we won't run the encoder and therefore shouldn't try to recv. + recv_buffer_seq_len = 0 + + # If the pipeline stage only has a vision encoder, then it only needs to + # run when there are image tokens + if not (self._encoder_only and num_image_tokens == 0): + output = super().run_one_forward_step( + inference_input, recv_buffer_seq_len=recv_buffer_seq_len + ) + else: + output = None + logits = output + + # On the first inference iteration, we compute image tokens. + # On every PP stage(although inference params should only matter for decoder), + # update the sequence length offset by the number of image tokens. + if num_tokens > 1 and num_image_tokens > 0: + if "image_tokens_count" not in self.inference_params.key_value_memory_dict: + self.inference_params.key_value_memory_dict["image_tokens_count"] = ( + num_img_embeddings + ) + + if num_img_embeddings + num_tokens - num_image_tokens > decoder_seq_length: + self.inference_params.sequence_len_offset += decoder_seq_length - num_tokens + else: + self.inference_params.sequence_len_offset += ( + self.inference_params.key_value_memory_dict["image_tokens_count"] + - num_image_tokens + ) + + return logits diff --git a/megatron/core/inference/model_inference_wrappers/t5/__init__.py b/megatron/core/inference/model_inference_wrappers/t5/__init__.py new file mode 100644 index 0000000000..f8011007a5 --- /dev/null +++ b/megatron/core/inference/model_inference_wrappers/t5/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py new file mode 100644 index 0000000000..9dddb9ab8a --- /dev/null +++ b/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py @@ -0,0 +1,225 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. 
All rights reserved. +from collections import deque +from typing import Any, Dict, List, Optional + +import numpy +import torch + +from megatron.core import tensor_parallel +from megatron.core.datasets.t5_dataset import T5MaskedWordPieceDataset +from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( + AbstractModelInferenceWrapper, +) +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, +) +from megatron.core.models.T5 import T5Model +from megatron.core.utils import get_attr_wrapped_model + + +# pylint: disable=line-too-long +class T5InferenceWrapper(AbstractModelInferenceWrapper): + """Constructor for the model inference wrapper + + The wrapper prepares the model for inference, provides the required input + data, and runs the forward pass + + Args: + model (T5Model): The T5 model (MCore or legacy) + inference_wrapper_config (InferenceWrapperConfig): The command line arguments that were passed + use_local (bool): Whether the T5 model's transformer impl + is local (vs transformer_engine) + """ + + def __init__( + self, + model: T5Model, + inference_wrapper_config: InferenceWrapperConfig, + use_local: bool = False, + ): + super().__init__(model, inference_wrapper_config) + self.use_local = use_local + + def prep_inference_input( + self, + prompts_tokens: torch.Tensor, + encoder_prompts: Optional[List[str]] = None, + tokenizer: Any = None, + ) -> Dict[str, Any]: + """Prepares the inference input data. + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len] + encoder_prompts (dict): List of string of encoder input prompts + tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text + + Returns: + A dict with all the inference input needed for the batch. 
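End to end, the T5 wrapper above is exercised roughly as sketched below; `t5_model`, `tokenizer`, `decoder_prompt_tokens`, and `inference_wrapper_config` are assumed to exist and to match the interfaces used in this file (e.g. the tokenizer exposes `tokenize`, `pad`, and `additional_special_tokens_ids`):

from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import (
    T5InferenceWrapper,
)

wrapped_t5 = T5InferenceWrapper(t5_model, inference_wrapper_config, use_local=False)
wrapped_t5.prep_model_for_inference(decoder_prompt_tokens)

inference_input = wrapped_t5.prep_inference_input(
    prompts_tokens=decoder_prompt_tokens,
    encoder_prompts=["translate English to German: the house is wonderful."],
    tokenizer=tokenizer,
)
# Without KV-cache support, each step runs over the full decoder context so far.
batch = wrapped_t5.get_batch_for_context_window(
    inference_input,
    context_start_position=0,
    context_end_position=decoder_prompt_tokens.size(1),
)
logits = wrapped_t5.run_one_forward_step(batch)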
+ """ + # get max_sequence_length + max_sequence_length = get_attr_wrapped_model(self.model, "max_sequence_length") + + encoder_prompts_tokens_list = [ + self.tokenize_encoder_prompt(encoder_prompt, tokenizer) + for encoder_prompt in encoder_prompts + ] + batch_encoder_prompts_tokens = self.pad_encoder_prompts_tokens( + encoder_prompts_tokens_list, max_sequence_length, tokenizer + ) + + # create batch mask for encoder_prompt (self.batch_input_tokens) and + # decoder_input (prompts_tokens), similar to megatron/core/datasets/t5_dataset.py + decoder_prompts_tokens = prompts_tokens + encoder_prompts_tokens = batch_encoder_prompts_tokens + decoder_prompts_tokens_numpy = decoder_prompts_tokens.cpu().numpy() + encoder_prompts_tokens_numpy = encoder_prompts_tokens.cpu().numpy() + batch_mask_encoder = [] + batch_mask_decoder = [] + for i in range(len(prompts_tokens)): + mask_encoder = encoder_prompts_tokens_numpy[i] == tokenizer.pad + mask_decoder = decoder_prompts_tokens_numpy[i] == tokenizer.pad + batch_mask_encoder.append(mask_encoder) + batch_mask_decoder.append(mask_decoder) + batch_mask_encoder = torch.tensor(numpy.array(batch_mask_encoder)).cuda() + batch_mask_decoder = torch.tensor(numpy.array(batch_mask_decoder)).cuda() + + return { + "encoder_tokens": encoder_prompts_tokens, + "decoder_tokens": decoder_prompts_tokens, + "encoder_mask": batch_mask_encoder, + "decoder_mask": batch_mask_decoder, + } + + def tokenize_encoder_prompt(self, encoder_prompt: str, tokenizer) -> torch.Tensor: + """Utility to tokenize the encoder_prompt + + Args: + encoder_prompt (str): The encoder_prompt + tokenizer (_type_): Tokenizer used for tokenizing and detokenizing string + + Returns: + torch.Tensor: Returns the tokenized prompt + """ + + # if there is the word "" in prompt, replacing it with special_additional_token, + # similar to processing step in megatron/core/datasets/t5_dataset.py + divided_encoder_prompt_list = encoder_prompt.split("") + masks_count = len(divided_encoder_prompt_list) - 1 + sentinels = deque(tokenizer.additional_special_tokens_ids) + + encoder_prompt_tokens = [] + for divided_encoder_prompt in divided_encoder_prompt_list: + divided_encoder_prompt_tokens = tokenizer.tokenize(divided_encoder_prompt) + encoder_prompt_tokens.extend(divided_encoder_prompt_tokens) + if masks_count > 0: + sentinel = sentinels.popleft() + encoder_prompt_tokens.extend([sentinel]) + masks_count -= 1 + + return encoder_prompt_tokens + + def pad_encoder_prompts_tokens( + self, encoder_prompts_tokens_list: List[List[int]], max_sequence_length: int, tokenizer + ) -> torch.Tensor: + """Method to pad input prompts + + Given a list of prompts, pad them all to uniform length + + Args: + encoder_prompts_tokens_list (List[List[int]]): A list containing the + encoder_input_tokens + max_sequence_length (int): Maximum of the length of the encoder inputs tokens + tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text + + Returns: + torch.Tensor: A torch tensor of shape [bs, max_sequence_length] + """ + + for encoder_prompt_tokens in encoder_prompts_tokens_list: + padding_size = max_sequence_length - len(encoder_prompt_tokens) + encoder_prompt_tokens.extend([tokenizer.pad] * padding_size) + + return torch.tensor(encoder_prompts_tokens_list).cuda() + + def get_batch_for_context_window( + self, + inference_input: Dict[str, Any], + context_start_position: int, + context_end_position: int, + ) -> Dict[str, Any]: + """Returns the inference data given context window + + This function gets called iteratively in a 
loop . Given the start and end context + positions , it extracts the appropriate data. + + Args: + inference_input (Dict[str, Any]): The inference input for the batch. + context_start_position (int): Start of the context window. During + the first inference step it is mostly 0 + context_end_position (int): End of the context window. During the + last inference step it will mostly be the max generated sequence length. + + Returns: + Dict: A dict of inputs that will be used by your model in the forward step + """ + + # T5 inference not yet support kv_cache + encoder_tokens2use = inference_input["encoder_tokens"] + decoder_tokens2use = inference_input["decoder_tokens"][:, :context_end_position] + encoder_mask2use = inference_input["encoder_mask"] + decoder_mask2use = inference_input["decoder_mask"][:, :context_end_position] + + # Configure attention mask based on different conditions + # (e.g., transformer-impl, TE versions, TE backends) + [encoder_mask2use, decoder_mask2use, encoder_decoder_mask2use] = ( + T5MaskedWordPieceDataset.config_attention_mask( + encoder_tokens2use, + decoder_tokens2use, + encoder_mask2use, + decoder_mask2use, + self.use_local, + ) + ) + + return { + "encoder_tokens": encoder_tokens2use, + "decoder_tokens": decoder_tokens2use, + "encoder_mask": encoder_mask2use, + "decoder_mask": decoder_mask2use, + "encoder_decoder_mask": encoder_decoder_mask2use, + } + + def forward_pass_without_pipeline_parallel( + self, inference_input: Dict[str, Any] + ) -> torch.Tensor: + """Utility to carry out simple forward pass for TP or no model parallel models + + Runs a very simple forward pass for model. Used in the case of models without + any parallelism or only tensor parallelism. + + Args: + inference_input (Dict[str, Any]): A dict containg the inputs for the gpt + model [tokens, position ids, attention mask] + + Returns: + torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size] + """ + encoder_tokens = inference_input["encoder_tokens"] + decoder_tokens = inference_input["decoder_tokens"] + encoder_mask = inference_input["encoder_mask"] + decoder_mask = inference_input["decoder_mask"] + encoder_decoder_mask = inference_input["encoder_decoder_mask"] + tokens = decoder_tokens + + # T5 inference not yet support kv_cache + logits = self.model( + encoder_tokens, + decoder_tokens, + encoder_mask, + decoder_mask, + encoder_decoder_mask, + inference_params=None, + ) + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + + return logits diff --git a/megatron/core/inference/modelopt_support/__init__.py b/megatron/core/inference/modelopt_support/__init__.py new file mode 100644 index 0000000000..885d2b3f01 --- /dev/null +++ b/megatron/core/inference/modelopt_support/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +"""Integrations with NVIDIA TensorRT Model Optimizer (referred as ModelOpt). + +ModelOpt is a library comprising state-of-the-art model optimization techniques +including quantization and sparsity to compress model for efficient inference on +NVIDIA GPUs. ModelOpt is integrated with Megatron-core to provide a seamless +experience for users to optimize their Megatron-core models for inference. +More details on ModelOpt including installation and usage can be found at +https://github.com/NVIDIA/TensorRT-Model-Optimizer. 
+""" diff --git a/megatron/core/inference/modelopt_support/gpt/__init__.py b/megatron/core/inference/modelopt_support/gpt/__init__.py new file mode 100644 index 0000000000..f8011007a5 --- /dev/null +++ b/megatron/core/inference/modelopt_support/gpt/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/core/inference/modelopt_support/gpt/model_specs.py b/megatron/core/inference/modelopt_support/gpt/model_specs.py new file mode 100644 index 0000000000..30f78b1395 --- /dev/null +++ b/megatron/core/inference/modelopt_support/gpt/model_specs.py @@ -0,0 +1,68 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from typing import Optional + +from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + + +# Use this spec for ModelOpt PTQ and TensorRT-LLM export +def get_gpt_layer_modelopt_spec( + num_experts: Optional[int] = None, + local_core_attention: bool = False, + moe_grouped_gemm: bool = False, + remap_te_layernorm: bool = False, + qk_layernorm: bool = False, +) -> ModuleSpec: + """Mix the native spec with TENorm. + + This is essentially the native local spec except for the layernorm implementation + is using TENorm from Transformer-Engine. The issue is that FusedLayerNorm from apex + has stopped supporting RMSNorm needed by llama. 
+ """ + core_attention = DotProductAttention if local_core_attention else TEDotProductAttention + mlp = get_mlp_module_spec( + use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=False + ) + sharded_state_dict_keys_map = {} + if remap_te_layernorm: + if num_experts: + sharded_state_dict_keys_map = { + 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_' + } + else: + sharded_state_dict_keys_map = { + 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', + 'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_', + } + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=TENorm, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=ColumnParallelLinear, + core_attention=core_attention, + linear_proj=RowParallelLinear, + q_layernorm=TENorm if qk_layernorm else IdentityOp, + k_layernorm=TENorm if qk_layernorm else IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=TENorm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + # Map TE-layernorm-fusion keys back + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + ), + ) diff --git a/megatron/core/inference/modelopt_support/gpt/state_dict_hooks.py b/megatron/core/inference/modelopt_support/gpt/state_dict_hooks.py new file mode 100644 index 0000000000..15c3527c94 --- /dev/null +++ b/megatron/core/inference/modelopt_support/gpt/state_dict_hooks.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from logging import getLogger + +import torch + +logger = getLogger(__name__) + + +def mcore_gpt_load_legacy_state_dict_pre_hook( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs +): + """Register a pre-hook to fix the state_dict key difference. + + This prehook is used when trying to load the legacy Megatron-LM GPTModel into its + megatron/core variant that uses native ParallelLinear and Transformer-Engine Norm. + Only this particular spec supports post-training quantization and TensorRT-LLM + config export through `nvidia-modelopt` package. + + Args: + state_dict: state dictionary + prefix: module name prefix + local_metadata: local metatdata + strict: whether is in strict mode + missing_keys: missing state dict keys + unexpected_keys: unexpected state dict keys + error_msgs: error messages + """ + if "modelopt_state" in state_dict: + state_dict.pop("modelopt_state") + + if "language_model" in state_dict: + language_model_state_dict = state_dict.pop("language_model") + if "embedding" in language_model_state_dict: + if "word_embeddings" in language_model_state_dict["embedding"]: + for key, param in language_model_state_dict["embedding"]["word_embeddings"].items(): + state_dict.update({"embedding.word_embeddings." + key: param}) + if "position_embeddings" in language_model_state_dict["embedding"]: + for key, param in language_model_state_dict["embedding"][ + "position_embeddings" + ].items(): + state_dict.update({"embedding.position_embeddings." + key: param}) + if "transformer" in language_model_state_dict: + for key, param in language_model_state_dict["transformer"].items(): + state_dict.update({"decoder." + key: param}) + else: + for key, param in language_model_state_dict["encoder"].items(): + state_dict.update({"decoder." 
+ key: param}) + if "output_layer" in language_model_state_dict: + for key, param in language_model_state_dict["output_layer"].items(): + state_dict.update({"output_layer." + key: param}) + + if torch.distributed.get_rank() == 0: + logger.info("ModelOptGPTModel {}".format(state_dict.keys())) + + module_name_rewrite_list = [ + ("input_norm", "input_layernorm"), + (".attention.query_key_value", ".self_attention.linear_qkv"), + (".attention.dense", ".self_attention.linear_proj"), + ("self_attention.query_key_value", "self_attention.linear_qkv"), + ("self_attention.dense", "self_attention.linear_proj"), + ("post_attention_layernorm", "pre_mlp_layernorm"), + ("post_attention_norm", "pre_mlp_layernorm"), + ("dense_h_to_4h", "linear_fc1"), + ("dense_4h_to_h", "linear_fc2"), + ("final_norm", "final_layernorm"), + ] + + key_rewrite_list = [] + + for key, _ in state_dict.items(): + for old_name, new_name in module_name_rewrite_list: + if old_name in key: + key_rewrite_list += [(key, key.replace(old_name, new_name))] + + for old_key, new_key in key_rewrite_list: + if torch.distributed.get_rank() == 0: + logger.info("replace {} with {}".format(old_key, new_key)) + state_dict[new_key] = state_dict[old_key] + state_dict.pop(old_key) + + +def mcore_gpt_load_te_state_dict_pre_hook( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs +): + """Register a pre-hook to fix the state_dict key difference of. + + This prehook is used when trying to load the megatron/core GPTModel that uses a + fused Transformer-Engine ParallelLinear into the variant that uses native ParallelLinear + and Transformer-Engine Norm (effectively to restore the fusion). + Only this particular spec supports post-training quantization and TensorRT-LLM + config export through `nvidia-modelopt` package. + + Args: + state_dict: state dictionary + prefix: module name prefix + local_metadata: local metatdata + strict: whether is in strict mode + missing_keys: missing state dict keys + unexpected_keys: unexpected state dict keys + error_msgs: error messages + """ + if "modelopt_state" in state_dict: + state_dict.pop("modelopt_state") + + key_with_te_extra_state_to_pop = [] + + for key, _ in state_dict.items(): + if "_extra_state" in key: + key_with_te_extra_state_to_pop += [key] + + for key in key_with_te_extra_state_to_pop: + state_dict.pop(key) + + module_name_rewrite_list = [ + ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"), + ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), + ("mlp.linear_fc1.layer_norm_weight", "pre_mlp_layernorm.weight"), + ("mlp.linear_fc1.layer_norm_bias", "pre_mlp_layernorm.bias"), + ] + + key_rewrite_list = [] + + for key, _ in state_dict.items(): + for old_name, new_name in module_name_rewrite_list: + if old_name in key: + key_rewrite_list += [(key, key.replace(old_name, new_name))] + + for old_key, new_key in key_rewrite_list: + if torch.distributed.get_rank() == 0: + logger.info("replace {} with {}".format(old_key, new_key)) + state_dict[new_key] = state_dict[old_key] + state_dict.pop(old_key) diff --git a/megatron/core/inference/modelopt_support/mamba/__init__.py b/megatron/core/inference/modelopt_support/mamba/__init__.py new file mode 100644 index 0000000000..e76ed74857 --- /dev/null +++ b/megatron/core/inference/modelopt_support/mamba/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
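For readers wiring up the ModelOpt pieces above, here is a minimal, hypothetical sketch (not part of this diff) of how the layer spec and the TE load-state-dict pre-hook could be combined. The config values, vocab size, and checkpoint variable are illustrative placeholders, and the sketch assumes torch.distributed and Megatron's model-parallel state have already been initialized.

    # Hypothetical usage sketch only; values below are placeholders.
    from megatron.core.inference.modelopt_support.gpt.model_specs import get_gpt_layer_modelopt_spec
    from megatron.core.inference.modelopt_support.gpt.state_dict_hooks import (
        mcore_gpt_load_te_state_dict_pre_hook,
    )
    from megatron.core.models.gpt import GPTModel
    from megatron.core.transformer.transformer_config import TransformerConfig

    config = TransformerConfig(num_layers=2, hidden_size=64, num_attention_heads=4)
    model = GPTModel(
        config=config,
        transformer_layer_spec=get_gpt_layer_modelopt_spec(remap_te_layernorm=True),
        vocab_size=32000,
        max_sequence_length=1024,
    )

    # The hook signature matches torch.nn.Module's load_state_dict pre-hook contract,
    # so it can be registered directly; it rewrites TE-fused keys while loading.
    model._register_load_state_dict_pre_hook(mcore_gpt_load_te_state_dict_pre_hook)
    # model.load_state_dict(te_checkpoint_state_dict)  # te_checkpoint_state_dict is hypothetical

The same pattern should apply to the legacy-checkpoint hook; only the registered function changes.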
diff --git a/megatron/core/inference/modelopt_support/mamba/model_specs.py b/megatron/core/inference/modelopt_support/mamba/model_specs.py new file mode 100755 index 0000000000..7fc1c8bd01 --- /dev/null +++ b/megatron/core/inference/modelopt_support/mamba/model_specs.py @@ -0,0 +1,89 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules +from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules +from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + + +# Use this spec for ModelOpt PTQ and TensorRT-LLM export +def get_mamba_stack_modelopt_spec( + local_core_attention: bool = False, remap_te_layernorm: bool = False +) -> ModuleSpec: + """Mix the native spec with TENorm. + + This is essentially the native local spec except for the layernorm implementation + is using TENorm from Transformer-Engine. + """ + mamba_state_dict_keys_map = {} + transformer_state_dict_keys_map = {} + if remap_te_layernorm: + mamba_state_dict_keys_map = {'norm.': 'mixer.in_proj.layer_norm_'} + transformer_state_dict_keys_map = { + 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', + 'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_', + } + + mamba_layer = ModuleSpec( + module=MambaLayer, + submodules=MambaLayerSubmodules( + norm=TENorm, + mixer=ModuleSpec( + module=MambaMixer, + submodules=MambaMixerSubmodules( + in_proj=ColumnParallelLinear, out_proj=RowParallelLinear + ), + ), + mamba_bda=get_bias_dropout_add, + sharded_state_dict_keys_map=mamba_state_dict_keys_map, + ), + ) + + core_attention = DotProductAttention if local_core_attention else TEDotProductAttention + attention_layer = ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=TENorm, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=ColumnParallelLinear, + core_attention=core_attention, + linear_proj=RowParallelLinear, + ), + ), + self_attn_bda=get_bias_dropout_add, + sharded_state_dict_keys_map=transformer_state_dict_keys_map, + ), + ) + + mlp_layer = ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + pre_mlp_layernorm=TENorm, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + sharded_state_dict_keys_map=transformer_state_dict_keys_map, + ), + ) + + return ModuleSpec( + module=MambaStack, + submodules=MambaStackSubmodules( + mamba_layer=mamba_layer, attention_layer=attention_layer, mlp_layer=mlp_layer + ), + ) diff --git a/megatron/core/inference/sampling_params.py b/megatron/core/inference/sampling_params.py new file mode 100644 index 
0000000000..0bb81bd3b0 --- /dev/null +++ b/megatron/core/inference/sampling_params.py @@ -0,0 +1,36 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from dataclasses import dataclass + + +@dataclass +class SamplingParams: + """Inference parameters sent along with the prompts. + This class contains request-level attributes that control the sampling techniques used when + generating text. This is distinct from megatron.core.InferenceParams, which sets model-level + inference attributes such as the maximum sequence length, and contains the KV cache. + + For an explanation of these parameters refer to this blog + https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and- + temperature-parameters-ed6a31313910 + """ + + temperature: float = 1.0 + top_k: int = 0 + top_p: float = 0.0 + return_log_probs: bool = False + return_segments: bool = False # Whether to return individually detokenized tokens + num_tokens_to_generate: int = 30 + + def add_attributes(self, attribute_value_pair: dict): + """Utility to add more attributes to sampling params + + Use this method to pass in a custom dictionary to add more sampling parameter attributes. + c = SamplingParams() + c.add_attributes({'min_length':4, 'eod_id':153}) + + Args: + attribute_value_pair (dict): A dictionary containing attributes as the key names and + their values as the values. + """ + for key, value in attribute_value_pair.items(): + setattr(self, key, value) diff --git a/megatron/core/inference/scheduler.py b/megatron/core/inference/scheduler.py new file mode 100644 index 0000000000..7300d482d0 --- /dev/null +++ b/megatron/core/inference/scheduler.py @@ -0,0 +1,175 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import functools +import time +import typing +from collections import OrderedDict +from typing import Dict, Optional, Type, Union + +import torch + +from megatron.core.inference.async_stream import AsyncStream +from megatron.core.inference.inference_request import InferenceRequest, Status +from megatron.core.inference.sampling_params import SamplingParams +from megatron.core.inference.utils import Counter + + +class Scheduler: + """Scheduler for handling requests to the inference engine + + This class is responsible for handling all of the incoming requests + + Args: + max_batch_size (int): The max batch size that we can pass to the + inference engine at a time. + request_type (InferenceRequest): The class to use for instantiating new requests.
+ """ + + def __init__(self, max_batch_size): + self.max_batch_size = max_batch_size + self.requests: Dict[str, InferenceRequest] = OrderedDict() + self.streams: Dict[str, AsyncStream] = OrderedDict() + self.active_request_pool: Dict[str, InferenceRequest] = OrderedDict() + self.waiting_request_pool: Dict[str, InferenceRequest] = OrderedDict() + self.completed_request_pool: Dict[str, InferenceRequest] = OrderedDict() + self.request_counter = Counter() + + def get_new_request_id(self) -> str: + """Gets a new request id""" + request_id = str(next(self.request_counter)) + return request_id + + def add_request( + self, + prompt: Optional[str] = None, + prompt_tokens: Optional[torch.Tensor] = None, + encoder_prompt: Optional[str] = None, + inference_parameters: Optional[SamplingParams] = None, + arrival_time: Optional[float] = None, + streaming: bool = False, + inference_request: Optional[InferenceRequest] = None, + ) -> str: + """Add an incoming request + + This method will add the request to either the active pool or the waiting pool + depending on the batch size. + + Args: + prompt (str): Input prompt string + prompt_tokens (torch.Tensor): A torch tensor having the input prompts tokenized + encoder_prompt (str): Encoder input string + inference_parameters (SamplingParams): The inference parameters + arrival_time (float, optional): The incoming request time. Defaults to None. + streaming (bool, optional): Whether to asynchronously stream tokens for this request. + inference_request (InferenceRequest, optional): A fully constructed request. + Defaults to None. + + Returns: + The request_id for the new request. + """ + status = ( + Status.ACTIVE_BUT_NOT_GENERATING_TOKENS + if len(self.active_request_pool) < self.max_batch_size + else Status.WAITING_IN_QUEUE + ) + + if inference_request is None: + assert prompt is not None + assert prompt_tokens is not None + + request_id = self.get_new_request_id() + + if arrival_time is None: + arrival_time = time.time() + + inference_request = InferenceRequest( + request_id=request_id, + prompt=prompt, + inference_parameters=inference_parameters, + arrival_time=arrival_time, + prompt_tokens=prompt_tokens, + status=status, + encoder_prompt=encoder_prompt, + ) + else: + request_id = inference_request.request_id + inference_request.status = status + if inference_request.arrival_time is None: + inference_request.arrival_time = time.time() + + self.requests[request_id] = inference_request + + if streaming: + abort_request = functools.partial(self.abort_request, request_id=request_id) + self.streams[request_id] = AsyncStream(request_id, abort_request) + + if status == status.ACTIVE_BUT_NOT_GENERATING_TOKENS: + self.active_request_pool[request_id] = inference_request + else: + self.waiting_request_pool[request_id] = inference_request + + return request_id + + def have_requests_pending(self) -> bool: + """Method to check if there are requests pending + + This method returns False only when there are no active requests or waiting requests. + """ + num_requests_pending = len(self.active_request_pool) + len(self.waiting_request_pool) + return num_requests_pending > 0 + + def add_earliest_waiting_request_to_active_pool(self): + """Utility to add the waiting request to active pool + + This method will add the earliest request (FIFO) that is in the waiting request + pool to the active request pool. + """ + assert ( + len(self.active_request_pool) < self.max_batch_size + ), "Active request pool is already full. 
Cant add any more requests" + if len(self.waiting_request_pool) > 0: + (earliest_waiting_request_request_id, earliest_waiting_request) = ( + self.waiting_request_pool.popitem(last=False) + ) + earliest_waiting_request.status = Status.ACTIVE_BUT_NOT_GENERATING_TOKENS + self.active_request_pool[earliest_waiting_request_request_id] = earliest_waiting_request + + def update_requests_pools( + self, result_dict: Optional[typing.OrderedDict[str, InferenceRequest]] = None + ): + """Update request pool status + + This method will full up the active request pool, if it has less than max batch size + elements from the waiting request pool. + If provided with a request dict, it will put the completed requests into the completed + request pool and add waiting request into active pool. + + Args: + result (typing.OrderedDict[str, InferenceRequest], optional): The result returned + by the engine. A dictionary with keys as the request ids, and values as the + requests. Defaults to None + """ + for result_request_id in list(result_dict.keys()): + active_request = self.active_request_pool[result_request_id] + + # If a request has completed put it into the completed request pool. + if active_request.status == Status.COMPLETED: + completed_request = self.active_request_pool.pop(result_request_id) + self.completed_request_pool[result_request_id] = completed_request + + # If the active request pool is not full, add waiting requests in FIFO order + while ( + len(self.active_request_pool) < self.max_batch_size + and len(self.waiting_request_pool) > 0 + ): + self.add_earliest_waiting_request_to_active_pool() + + def abort_request( + self, + request_id: str, + *, + exception: Optional[Union[BaseException, Type[BaseException]]] = None + ): + """Cancels the given request""" + stream = self.streams.get(request_id, None) + if stream is not None: + stream.finish(exception=exception) diff --git a/megatron/core/inference/text_generation_controllers/__init__.py b/megatron/core/inference/text_generation_controllers/__init__.py new file mode 100644 index 0000000000..f8011007a5 --- /dev/null +++ b/megatron/core/inference/text_generation_controllers/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py b/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py new file mode 100644 index 0000000000..d10dfe1a01 --- /dev/null +++ b/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py @@ -0,0 +1,38 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
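As a usage note for the Scheduler introduced above (not part of the diff itself): a driver loop might interact with the request pools roughly as sketched below. Here run_engine_step is a hypothetical stand-in for the engine's forward/update step, and the prompt tokens are placeholders rather than real tokenizer output.

    # Illustrative sketch only; run_engine_step is hypothetical.
    import torch

    from megatron.core.inference.sampling_params import SamplingParams
    from megatron.core.inference.scheduler import Scheduler

    scheduler = Scheduler(max_batch_size=4)

    request_id = scheduler.add_request(
        prompt="Hello world",
        prompt_tokens=torch.tensor([1, 2, 3]),  # placeholder tokenization
        inference_parameters=SamplingParams(num_tokens_to_generate=16),
    )

    while scheduler.have_requests_pending():
        # run_engine_step would return an OrderedDict of updated InferenceRequest
        # objects keyed by request id (e.g. from the inference engine's step).
        result_dict = run_engine_step(scheduler.active_request_pool)
        scheduler.update_requests_pools(result_dict=result_dict)

    completed_request = scheduler.completed_request_pool[request_id]

The key invariant is that update_requests_pools moves completed requests out of the active pool and back-fills it from the waiting pool in FIFO order, so the active pool never exceeds max_batch_size.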
+from typing import Any, Dict, OrderedDict + +import torch + +from megatron.core.inference.inference_request import InferenceRequest +from megatron.core.inference.text_generation_controllers.text_generation_controller import ( + TextGenerationController, +) + + +class EncoderDecoderTextGenerationController(TextGenerationController): + """The text generation controller for encoder-decoder architecture + + This class inherits from TextGenerationController, adding features + relating to encoder input encoder_prompt + + """ + + def prep_inference_input( + self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest] + ) -> Dict[str, Any]: + """Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length] + active_requests (OrderedDict[str, InferenceRequest]): The input active requests + + Returns: + A dict of the inference input for the current batch. + """ + encoder_prompts = list( + map(lambda request: request.encoder_prompt, active_requests.values()) + ) + + return self.inference_wrapped_model.prep_inference_input( + prompts_tokens, encoder_prompts, tokenizer=self.tokenizer + ) diff --git a/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py b/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py new file mode 100644 index 0000000000..340cadb48a --- /dev/null +++ b/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py @@ -0,0 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from megatron.core.inference.text_generation_controllers.text_generation_controller import ( # noqa: F401 # pylint: disable=unused-import + TextGenerationController as SimpleTextGenerationController, +) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py new file mode 100644 index 0000000000..f752f06f15 --- /dev/null +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -0,0 +1,674 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import concurrent +import copy +import functools +from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union + +import torch +import torch.nn.functional as F + +from megatron.core import parallel_state +from megatron.core.inference.async_stream import AsyncStream +from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage +from megatron.core.inference.inference_request import InferenceRequest, Status +from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( + AbstractModelInferenceWrapper, +) +from megatron.core.inference.sampling_params import SamplingParams +from megatron.core.transformer.cuda_graphs import create_cudagraphs +from megatron.core.utils import get_model_config + + +class TextGenerationController: + """The text generation controller (the main sampling loop) + + This class tokenizes the input, runs inference, samples from logits, and detokenizes the output. 
+ + Args: + inference_wrapped_model (AbstractModelInferenceWrapper): A model that + is wrapped using the specs given in the abstract_model_inference_wrapper.py + tokenizer (_type_): Tokenizer used for tokenizing and detokenizing the prompts + """ + + def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, tokenizer): + self.inference_wrapped_model = inference_wrapped_model + self.tokenizer = tokenizer + + # For models without pipeline parallelism, is_first_stage and is_last_stage returns True + self.model_is_pipeline_parallel = not ( + parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage() + ) + + def tokenize_prompt( + self, prompt: str, add_BOS: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Utility to tokenize the input prompts + + Args: + prompt (str): The input prompt + + Returns: + torch.Tensor: Returns the tokenized prompt + """ + prompt_tokens = self.tokenizer.tokenize(prompt) + + if add_BOS: + prompt_tokens = [self.tokenizer.bos] + prompt_tokens + + return prompt_tokens + + def detokenize_generations( + self, + tokens_gpu_tensor: torch.Tensor, + lengths_gpu_tensor: torch.Tensor, + detokenize_segments: bool, + ) -> tuple[str, Optional[List[List[str]]]]: + """Detokenize the generated tokens. + + Args: + tokens_gpu_tensor (torch.Tensor): Tensor containing the tokens + lengths_gpu_tensor (torch.Tensor): Tensor containing the lengths of each sequence + detokenize_segments (bool): If True, returns individually detokenized tokens. If False, + returns None as second element. Helpful for understanding per-token boundaries in + generated text. + + Returns: + tuple[str, List[str] | None]: A tuple containing: + - str: The complete detokenized text + - List[str] | None: List of segmented tokens if detokenize_segments is True, else None + """ + # TODO(helenn): Unify with `detokenize_generations` from legacy textgen path + + if not detokenize_segments: + tokens = tokens_gpu_tensor.cpu().numpy().tolist() + return self.tokenizer.detokenize(tokens), None + + prompts_plus_generations: List[str] = [] + prompts_plus_generations_segments: List[List[str]] = [] + + tokens_gpu_tensor = torch.unsqueeze(tokens_gpu_tensor, 0) + tokens = tokens_gpu_tensor.cpu().numpy().tolist() + lengths = lengths_gpu_tensor.cpu().numpy().tolist() + + for sequence_tokens, length in zip(tokens, lengths): + sequence_tokens = sequence_tokens[:length] + detok_str = self.tokenizer.detokenize(sequence_tokens) + prompts_plus_generations.append(detok_str) + offsets = self.tokenizer.offsets(sequence_tokens, detok_str) + words = [ + detok_str[start:end] for start, end in zip(offsets, offsets[1:] + [len(detok_str)]) + ] + + prompts_plus_generations_segments.append(words) + + text = self.tokenizer.detokenize(tokens[0]) + + return text, prompts_plus_generations_segments + + def sample_from_logits( + self, + last_token_logits: torch.Tensor, + sampling_params: Optional[SamplingParams] = None, + vocab_size: Optional[int] = None, + **kwargs, + ) -> torch.Tensor: + """Samples the logits to generate outputs + + Given the logits of the last token, this function samples it + according to the parameters defined in sampling_params + and returns the samples + + Args: + last_token_logits (torch.Tensor): The last token logits. A tensor of + size [batch_size, vocab_size] + sampling_params (SamplingParams): The parameters to use for inference. + vocab_size (int): Obtained from the tokenizer. 
Defaults to None + + Returns: + torch.Tensor: 1D tensor of the sampled logits with [batch_size] elements + """ + + if kwargs.get('common_inference_params'): + sampling_params = kwargs['common_inference_params'] + + top_p = sampling_params.top_p + top_k = sampling_params.top_k + temperature = sampling_params.temperature + + assert not (top_k > 0 and top_p > 0), 'Cannot have top-p and top-k both greater than zero' + assert top_p <= 1.0, 'top-p should be in (0,1]' + + def modify_logits_for_top_k_filtering(logits, top_k): + """Set the logits for none top-k values to -inf.""" + filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits.masked_fill_(filter_, float('-Inf')) + + def modify_logits_for_top_p_filtering(logits, top_p): + """Set the logits for none top-p values to -inf.""" + # First sort and calculate cumulative sum of probabilities. + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + # Filteration based on the cumulative sum. + filter_ = cumulative_probs > top_p + # This shift by 1 is weird and I cannot justify it. This existed + # in the original implementation: + # https://github.com/ari-holtzman/degen/blob/master/gen.py + # and I guess it is needed so keeping it for now. + filter_[:, 1:] = filter_[:, :-1].clone() + # Make sure we at least have one token to select from. + filter_[..., 0] = 0 + + # Fill in the filtered part + filter_ = filter_.scatter(1, sorted_indices, filter_) + logits.masked_fill_(filter_, float('-Inf')) + + # Greedy sampling + if top_k == 1: + sampled_logits = torch.argmax(last_token_logits, dim=-1) + else: + last_token_logits = last_token_logits.clone() + if temperature != 1.0: + last_token_logits.div_(temperature) + + if top_k > 1: + assert top_k <= last_token_logits.size(1), 'top-k is larger than logit size.' + if vocab_size: + assert top_k < vocab_size, 'top-k is larger than vocab size.' + modify_logits_for_top_k_filtering(last_token_logits, top_k) + + elif top_p > 0.0: + modify_logits_for_top_p_filtering(last_token_logits, top_p) + + # After filtering, we need to recalculate the distribution. + probabilities = last_token_logits.softmax(dim=-1) + sampled_logits = torch.multinomial(probabilities, num_samples=1).view(-1) + + # If vocab size is provided, make sure the samples are in in the range [0, vocab-size). + if vocab_size: + sampled_logits = torch.clamp(sampled_logits, min=0, max=(vocab_size - 1)) + return sampled_logits + + def update_generation_status( + self, + updated_prompts_tokens: torch.Tensor, + generation_started: torch.Tensor, + current_context_end_position: int, + is_generation_done_tensor: torch.Tensor, + generated_sequence_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Checks which prompts have reached an end condition + + We check which prompts have reached an end condition and set the corresponding + flags of the is_generation_done_tensor to True. The generated sequence lengths + increase as we keep generating, until that prompts hits an end condition. The + generation_started tensor determines which prompts have started generating. + + Args: + updated_prompts_tokens (torch.Tensor): The prompts tokens updated with the latest + generated tokens. A tensor of shape [batch_size, max_seq_len] + (i.e max_seq_len = max_prompt_len + tokens_to_generate) + generation_started (torch.Tensor): A boolean tensor of shape [batch_size]. True + indicates the prompt at that index has started generating tokens. 
+ current_context_end_position (int): An integer indicating which position to + extract from the prompt tokens to get the latest generated tokens. + is_generation_done_tensor (torch.Tensor): A boolean tensor of shape [batch_size]. + True indicates the prompt at that index has reached the end condition. + generated_sequence_lengths (torch.Tensor): An int tensor of shape [batch_size]. + Each value represents the generated sequence length for that prompt. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Returns the boolean + is_generation_done_tensor and the generated_sequence_lengths after updating it + """ + latest_samples = updated_prompts_tokens[:, current_context_end_position] + # Make sure we are checking eod criterion only for prompts that have started generating + # (i.e.) We only look at the generated tokens and not the input tokens. + reached_eod = (latest_samples == self.tokenizer.eod) & generation_started + is_generation_done_tensor = is_generation_done_tensor | reached_eod + # We increment generated sequence lengths when that prompt has not hit the + # EOD and generation has started + generated_sequence_lengths += ~is_generation_done_tensor & generation_started + + return is_generation_done_tensor, generated_sequence_lengths.int() + + def pad_input_prompt_tokens( + self, + batch_prompt_tokens_list: List[List[int]], + max_prompt_length_in_batch: int, + num_tokens_to_generate: int, + ) -> torch.Tensor: + """Method to pad input prompts + + Given a list of prompts, pad them all to uniform length + + Args: + batch_prompt_tokens_list (List[List[int]]): A list containing the prompt tokens + max_prompt_length_in_batch (int): Maximum length of the input prompt tokens + num_tokens_to_generate (int): The number of tokens to generate for each prompt + + Returns: + torch.Tensor: A torch tensor of shape [bs, max_seq_len] (i.e.) + max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate + """ + max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate + + for prompt_tokens in batch_prompt_tokens_list: + padding_size = max_seq_len - len(prompt_tokens) + prompt_tokens.extend([self.tokenizer.eod] * padding_size) + + return torch.tensor(batch_prompt_tokens_list, device=torch.cuda.current_device()) + + def generate_output_tokens_dynamic_batch( + self, active_requests: OrderedDict[str, InferenceRequest] + ) -> OrderedDict[str, InferenceRequest]: + """Utility to generate the output tokens and probabilities for the prompts + + This utility generates the output tokens for a dynamic batch. It will run one forward step + at a time, and pass control back to the engine, which will update the request pool and call + this method again. + + Args: + active_requests (OrderedDict[str, InferenceRequest]): The input active requests. + + Returns: + OrderedDict[str, InferenceRequest]: The result for each of the incoming requests + after running one forward step. + """ + raise Exception("Not implemented yet") + + def generate_all_output_tokens_static_batch( + self, + active_requests: OrderedDict[str, InferenceRequest], + active_streams: Optional[OrderedDict[str, AsyncStream]] = None, + ) -> OrderedDict[str, InferenceRequest]: + """Utility to generate all the output tokens and probabilities for the prompts. + + This utility generates the output tokens for a static batch.
It runs the forward steps till + all prompts complete generation, updates the status of these requests to completed, adds + the generated result and returns these requests + + Args: + active_requests (OrderedDict[str, InferenceRequest]): The input active requests. + + Returns: + OrderedDict[str, InferenceRequest]: The result for each of the incoming requests + """ + assert all(request.prompt_tokens is not None for request in active_requests.values()) + + # Perform a deep copy so that the request prompt tokens do not get modified. + batch_prompt_tokens_list: List[List[int]] = list( + map( + lambda request: copy.deepcopy(request.prompt_tokens), # type: ignore[arg-type] + active_requests.values(), + ) + ) + prompt_lengths_in_batch = torch.tensor( + [len(prompt_tokens) for prompt_tokens in batch_prompt_tokens_list], + device=torch.cuda.current_device(), + ) + max_prompt_length_in_batch = max(prompt_lengths_in_batch) + min_prompt_length_in_batch = min(prompt_lengths_in_batch) + + # For batch inference the inference params are the same for all request + sampling_params: SamplingParams = list(active_requests.values())[0].inference_parameters + + # max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate + batch_prompt_tokens = self.pad_input_prompt_tokens( + batch_prompt_tokens_list, + max_prompt_length_in_batch=max_prompt_length_in_batch, + num_tokens_to_generate=sampling_params.num_tokens_to_generate, + ) + batch_size, max_sequence_length = batch_prompt_tokens.shape + + # Verify that output sequence length is within configured limit + # TODO(ksanthanam): Raise TokenOverflowError once !2518 is merged + inference_max_sequence_length = ( + self.inference_wrapped_model.inference_wrapper_config.inference_max_seq_length + ) + assert max_sequence_length <= inference_max_sequence_length, ( + f"Maximum allowed sequence length was set to {inference_max_sequence_length} tokens " + f"but requested generation of {max_sequence_length} tokens" + ) + + # Pre allocate log probs tensor + output_log_probs = None + if sampling_params.return_log_probs: + output_log_probs = torch.empty( + (batch_size, max_sequence_length - 1), + dtype=torch.float32, + device=torch.cuda.current_device(), + ) + + # An array to check which of the prompts have reached end of generation condition + is_generation_done_tensor = torch.zeros( + batch_size, dtype=torch.bool, device=torch.cuda.current_device() + ) + + # An array to act as a counter to keep track of generated sequence lengths + generated_sequence_lengths = torch.zeros( + batch_size, device=torch.cuda.current_device() + ).cuda() + + # Use padded vocab size because tokenizer vocab size might not include padding + # to nearest power of 2 + vocab_size = self.inference_wrapped_model.inference_wrapper_config.padded_vocab_size + + # Check whether CUDA graphs are enabled + enable_cuda_graph = get_model_config(self.inference_wrapped_model.model).enable_cuda_graph + + streaming_enabled = active_streams is not None and len(active_streams) > 0 + if streaming_enabled: + # Start a separate thread for streaming tokens to avoid blocking the + # main computation + streaming_idx: List[int] = [ + i + for (i, request_id) in enumerate(active_requests.keys()) + if request_id in active_streams + ] + streaming_request_ids: List[str] = list(active_streams.keys()) + streams: List[AsyncStream] = list(active_streams.values()) + streaming_requests: List[InferenceRequest] = [ + active_requests[request_id] for request_id in streaming_request_ids + ] + streaming_executor = 
concurrent.futures.ThreadPoolExecutor(max_workers=1) + stream_tokens = functools.partial(self.stream_tokens, sampling_params) + + with torch.no_grad(): + + self.inference_wrapped_model.prep_model_for_inference( + prompts_tokens=batch_prompt_tokens + ) + + inference_input: Dict[str, Any] = self.prep_inference_input( + prompts_tokens=batch_prompt_tokens, active_requests=active_requests + ) + + assert ( + not self.inference_wrapped_model.inference_params.decode_mode + ), f"Generation must start in prefill mode" + + context_start_position = 0 + # Pick the context window that we need to pass through the network. + for context_end_position in range(min_prompt_length_in_batch, max_sequence_length): + + inference_input_for_context_window: Dict[str, Any] = ( + self.inference_wrapped_model.get_batch_for_context_window( + inference_input, context_start_position, context_end_position + ) + ) + + # Disable attention mask when using CUDA graphs for decode + if ( + enable_cuda_graph + and self.inference_wrapped_model.inference_params.decode_mode + and "attention_mask" in inference_input_for_context_window + ): + inference_input_for_context_window["attention_mask"] = None + + # Returns the final logits of shape [batch_size, context_length, vocab_size] + # Note: This is returned in all TP ranks or last PP stage in PP models + logits = self.inference_wrapped_model.run_one_forward_step( + inference_input_for_context_window + ) + + if enable_cuda_graph: + create_cudagraphs() + + if self.model_is_pipeline_parallel: + context_length = context_end_position - context_start_position + logits = broadcast_from_last_pipeline_stage( + [batch_size, context_length, vocab_size], + dtype=self.inference_wrapped_model.inference_wrapper_config.params_dtype, + tensor=logits, + ) + + # Indicates which of the input prompts have started generating tokens. 
+ # A 1D boolean tensor with [batch_size] elements (i.e) The shortest + # prompts will start generating first and so on + generation_started = prompt_lengths_in_batch <= context_end_position + last_token_logits = logits[:, -1, :] + sampled_logits = self.sample_from_logits( + last_token_logits, sampling_params, vocab_size + ) + + # Substitute the sampled logits only for the prompts that + # have started generating tokens + batch_prompt_tokens[generation_started, context_end_position] = sampled_logits[ + generation_started + ] + + if sampling_params.return_log_probs: + log_probs = F.log_softmax(logits, dim=2) + indices = torch.unsqueeze( + batch_prompt_tokens[ + :, (context_start_position + 1) : (context_end_position + 1) + ], + 2, + ) + # Get the log probabilities for only the prompt tokens + assert output_log_probs is not None + output_log_probs[:, context_start_position:context_end_position] = torch.gather( + log_probs, 2, indices + ).squeeze(2) + + context_start_position = context_end_position + + # Check end of generation status for each tensor + # and update generated sequence lengths + (is_generation_done_tensor, generated_sequence_lengths) = ( + self.update_generation_status( + updated_prompts_tokens=batch_prompt_tokens, + generation_started=generation_started, + current_context_end_position=context_end_position, + is_generation_done_tensor=is_generation_done_tensor, + generated_sequence_lengths=generated_sequence_lengths, + ) + ) + + # Stream intermediate outputs + if streaming_enabled: + streaming_executor.submit( + stream_tokens, + streaming_request_ids, + streaming_requests, + streams, + generation_started[streaming_idx].cpu(), + is_generation_done_tensor[streaming_idx].cpu(), + batch_prompt_tokens[streaming_idx].cpu(), + prompt_lengths_in_batch[streaming_idx].cpu(), + generated_sequence_lengths[streaming_idx].cpu(), + ( + output_log_probs[streaming_idx].cpu() + if output_log_probs is not None + else [None] * len(streaming_idx) + ), + ) + + # Boolean flag indicating if all prompts are finished + all_prompts_done = torch.all(is_generation_done_tensor) + if all_prompts_done: + break + + # Change to decode mode if all prefill is complete + if torch.all(generation_started): + self.inference_wrapped_model.inference_params.enable_decode_mode() + + # Close all streams + if streaming_enabled: + streaming_executor.shutdown() + for stream in streams: + stream.finish() + + # Include all the generated tokens + batch_prompt_tokens_with_generations = batch_prompt_tokens[:, : (context_end_position + 1)] + if sampling_params.return_log_probs: + assert output_log_probs is not None + output_log_probs = output_log_probs[:, :context_end_position] + + generated_sequence_lengths[ + generated_sequence_lengths > sampling_params.num_tokens_to_generate + ] = sampling_params.num_tokens_to_generate + + for idx, request in enumerate(active_requests.values()): + input_prompt_length = int(prompt_lengths_in_batch[idx]) + # Shorter prompts might have generated more than required tokens. 
So we trim them down + required_sequence_length = int( + min(generated_sequence_lengths[idx], sampling_params.num_tokens_to_generate) + ) + # Extract only the generated tokens + required_result_tokens = batch_prompt_tokens_with_generations[ + idx, input_prompt_length : (input_prompt_length + required_sequence_length) + ] + generated_sequence_lengths = generated_sequence_lengths.to(dtype=torch.int32) + request.generated_sequence_lengths = generated_sequence_lengths.to(dtype=torch.int32) + request.generated_length = required_sequence_length + request.generated_tokens = required_result_tokens + + request.prompt_log_probs = ( + None + if output_log_probs is None + else output_log_probs[idx, :input_prompt_length].cpu().numpy().tolist() + ) + + request.generated_log_probs = ( + None + if output_log_probs is None + else output_log_probs[ + idx, + input_prompt_length - 1 : (input_prompt_length + required_sequence_length - 1), + ] + .cpu() + .numpy() + .tolist() + ) + request.status = Status.COMPLETED + + text, segments = self.detokenize_generations( + batch_prompt_tokens_with_generations[idx], + input_prompt_length + generated_sequence_lengths, + sampling_params.return_segments, + ) + request.text = text # Inference server returns prompts & generations together + if sampling_params.return_segments: + request.segments = segments[0] + request.generated_text = text[len(request.prompt) :] + return active_requests + + def prep_inference_input( + self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest] + ) -> Dict[str, Any]: + """Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length] + active_requests (OrderedDict[str, InferenceRequest]): The input active requests + + Returns: + A dict of the inference input for the current batch. + """ + return self.inference_wrapped_model.prep_inference_input(prompts_tokens) + + def stream_tokens( + self, + sampling_params: SamplingParams, + request_ids: List[str], + requests: List[InferenceRequest], + streams: List[AsyncStream], + generation_started: List[bool], + is_generation_done: List[bool], + tokens: torch.Tensor, + prompt_lengths: List[int], + generated_lengths: List[int], + output_log_probs: Union[torch.Tensor, None], + ): + """Asynchronously streams tokens for the given requests. + + Args: + sampling_params (SamplingParams): The sampling parameters. + request_ids (List[str]): The request IDs. + request (List[InferenceRequest]): The requests. + stream (List[AsyncStream]): The streams over which to send tokens. + generation_started (List[bool]): Whether the decode step has started. + is_generation_done (List[bool]): Whether generation has completed. + tokens (torch.Tensor): The tokens for this request. + prompt_lengths (List[int]): The number of prompt tokens for each request. + generated_lengths (List[int]): The number of output tokens for each request. + output_log_probs (torch.Tensor, optional): The log probs for each request. 
+ """ + + def stream_token( + request_id: str, + request: InferenceRequest, + stream: AsyncStream, + generation_started: bool, + is_generation_done: bool, + tokens: torch.Tensor, + prompt_length: int, + generated_length: int, + output_log_probs: Union[torch.Tensor, None], + ): + """Asynchronously streams a token for the given request.""" + + if not generation_started or stream.finished: + return + + num_tokens_to_generate = sampling_params.num_tokens_to_generate + return_segments = sampling_params.return_segments + detokenize_streaming_text = not getattr( + sampling_params, "no_detokenize_streaming_text", False + ) + + generated_tokens = tokens[prompt_length : prompt_length + generated_length] + + if detokenize_streaming_text: + generated_text, generated_segments = self.detokenize_generations( + generated_tokens, prompt_length + generated_length, return_segments + ) + else: + generated_text = "" + generated_segments = [] + + if output_log_probs is not None: + generated_log_probs = ( + output_log_probs[prompt_length - 1 : prompt_length + generated_length - 1] + .cpu() + .numpy() + .tolist() + ) + else: + generated_log_probs = None + + stream.put( + InferenceRequest( + request_id=request_id, + prompt=request.prompt, + inference_parameters=request.inference_parameters, + prompt_tokens=request.prompt_tokens, + arrival_time=request.arrival_time, + status=request.status, + encoder_prompt=request.encoder_prompt, + generated_text=generated_text, + generated_segments=generated_segments, + generated_tokens=generated_tokens, + generated_log_probs=generated_log_probs, + generated_length=generated_length, + ) + ) + + if is_generation_done or generated_length == num_tokens_to_generate: + stream.finish() + + ret = map( + stream_token, + request_ids, + requests, + streams, + generation_started, + is_generation_done, + tokens, + prompt_lengths, + generated_lengths, + output_log_probs, + ) + list(ret) diff --git a/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py b/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py new file mode 100644 index 0000000000..517c2b2bad --- /dev/null +++ b/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from typing import OrderedDict + +import torch + +from megatron.core.inference.inference_request import InferenceRequest, VLMInferenceRequest +from megatron.core.inference.text_generation_controllers.text_generation_controller import ( + TextGenerationController, +) + + +class VLMTextGenerationController(TextGenerationController): + """The text generation controller for VLMs""" + + def prep_inference_input( + self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest] + ): + """Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long + + Currently only supports batch size 1 inference. 
+ + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length] + active_requests (OrderedDict[str, InferenceRequest]): The input active requests + """ + assert len(active_requests) == 1, f"VLM inference currently only supports batch size 1" + + request = list(active_requests.values())[0] + + assert isinstance( + request, VLMInferenceRequest + ), f"Found inference request of type {type(request)}, expected VLMInferenceRequest" + + return self.inference_wrapped_model.prep_inference_input( + prompts_tokens, + request.num_img_embeddings_per_tile, + request.imgs, + request.num_tiles, + request.decoder_seq_length, + ) diff --git a/megatron/core/inference/utils.py b/megatron/core/inference/utils.py new file mode 100644 index 0000000000..bdb1021ef5 --- /dev/null +++ b/megatron/core/inference/utils.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +class Counter: + """A simple counter class + + This class is responsible for assigning request ids to incoming requests + """ + + def __init__(self, start: int = 0) -> None: + self.counter = start + + def __next__(self) -> int: + i = self.counter + self.counter += 1 + return i + + def reset(self) -> None: + self.counter = 0 diff --git a/megatron/core/inference_params.py b/megatron/core/inference_params.py index 287902460f..fa5273f3b0 100644 --- a/megatron/core/inference_params.py +++ b/megatron/core/inference_params.py @@ -1,3 +1,4 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. class InferenceParams: """Inference parameters that are passed to the main model in order to efficienly calculate and store the context during inference.""" @@ -5,9 +6,12 @@ class InferenceParams: def __init__(self, max_batch_size, max_sequence_length): self.max_sequence_length = max_sequence_length self.max_batch_size = max_batch_size + self.current_batch_size = max_batch_size # Required for bookkeeping variable-sized batches self.sequence_len_offset = 0 self.batch_size_offset = 0 + self.decode_mode = False self.key_value_memory_dict = {} + self.decode_mode = False def swap_key_value_dict(self, batch_idx): "swap between batches" @@ -25,3 +29,72 @@ def swap_key_value_dict(self, batch_idx): new_inference_key_memory, new_inference_value_memory, ) + + def enable_prefill_mode(self): + """ + Indicates the generation loop is in the prefill phase (still processing + input prompt tokens). This should be enabled if the generation loop is + encoding prompt tokens for *any* request in a batch. + """ + self.decode_mode = False + + def enable_decode_mode(self): + """ + Indicates the generation loop is in the decode phase (generating new output + tokens). This should only be enabled if the generation loop has fully encoded + the prompts for *all* requests in a batch. 
+ """ + self.decode_mode = True + + def reset(self): + """Resets the inference state for a new batch.""" + self.current_batch_size = self.max_batch_size + self.sequence_len_offset = 0 + self.batch_size_offset = 0 + self.enable_prefill_mode() + + def __str__(self): + return ( + f"InferenceParams(max_seq_len = {self.max_sequence_length}, " + f"max_batch_size = {self.max_batch_size}, " + f"current_batch_size = {self.current_batch_size}, " + f"sequence_len_offset = {self.sequence_len_offset}, " + f"batch_size_offset = {self.batch_size_offset}, " + f"key_value_memory_dict = {self.key_value_memory_dict.keys()})" + f"decode_mode = {self.decode_mode}" + ) + + def __eq__(self, other): + + if not isinstance(other, InferenceParams): + return False + + # Check all attributes match + basic_attrs = [ + 'max_sequence_length', + 'max_batch_size', + 'current_batch_size', + 'sequence_len_offset', + 'batch_size_offset', + ] + + if not all(hasattr(other, attr) for attr in basic_attrs): + return False + + # Check dictionary keys match; i.e. the same number of layers are cached + if self.key_value_memory_dict.keys() != other.key_value_memory_dict.keys(): + return False + + # Check each tensor tuple in the dictionary + for key in self.key_value_memory_dict: + self_tensors = self.key_value_memory_dict[key] + other_tensors = other.key_value_memory_dict[key] + + # Compare each key, value tensor in the tuple + for self_tensor, other_tensor in zip(self_tensors, other_tensors): + if ( + self_tensor.data_ptr() != other_tensor.data_ptr() + or self_tensor.shape != other_tensor.shape + ): + return False + return True diff --git a/megatron/core/jit.py b/megatron/core/jit.py new file mode 100644 index 0000000000..5b1dfff3e7 --- /dev/null +++ b/megatron/core/jit.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +import torch + +from megatron.core.utils import is_torch_min_version + +jit_fuser = torch.jit.script +# nvFuser is deprecated in PyTorch JIT starting from 2.2 +if is_torch_min_version("2.2.0a0"): + jit_fuser = torch.compile diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 22d34da921..2d652e1ded 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -1,7 +1,7 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. from dataclasses import dataclass -from typing import Callable, Optional +from typing import Callable, ContextManager, Optional import torch @@ -10,201 +10,345 @@ class ModelParallelConfig: """Base configuration for Megatron Core - Model Parallelism - ----------------- - - tensor_model_parallel_size (int): Intra-layer model parallelism. Splits tensors across GPU ranks. Defaults to 1. - - context_parallel_size (int): Splits network input along sequence dimension across GPU ranks. Defaults to 1. - - pipeline_model_parallel_size (int): Inter-layer model parallelism. Splits transformer layers across GPU - ranks. Defaults to 1. - - virtual_pipeline_model_parallel_size (int): Interleaved pipeline parallelism is used to improve performance by - reducing the pipeline bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks. - The number of virtual blocks per pipeline model parallel rank is the virtual model parallel size. See Efficient - Large-Scale Language Model Training on GPU Clusters Using Megatron-LM: https://arxiv.org/pdf/2104.04473.pdf for - more details. Defaults to None. 
- - sequence_parallel (bool): Makes tensor parallelism more memory efficient for LLMs (20B+) by - parallelizing layer norms and dropout sequentially. See Reducing Activation Recomputation in Large Transformer - Models: https://arxiv.org/abs/2205.05198 for more details. Defaults to False. + The initialization function has an argument for each parameter. + """ - expert_model_parallel_size (int): Distributes Moe Experts across sub data parallel dimension. Defaults to False. + ################### + # Model parallelism + ################### + tensor_model_parallel_size: int = 1 + """Intra-layer model parallelism. Splits tensors across GPU ranks.""" - Initialization - -------------- + pipeline_model_parallel_comm_backend: Optional[str] = None + """Configuring backend option of pipeline parallel communication (e.g., nccl, ucc) + If None, the default backend will be used. + """ - perform_initialization (bool, default=True): If true, weights are initialized. This option can be useful when you - know you are going to load values from a checkpoint. + pipeline_model_parallel_size: int = 1 + """Inter-layer model parallelism. Splits transformer layers across GPU ranks.""" - use_cpu_initialization: (bool, default=False): When set to False, we initialize the weights directly on the GPU. - Transferring weights from CPU to GPU can take a significant amount of time for large models. Defaults to False. + virtual_pipeline_model_parallel_size: Optional[int] = None + """Interleaved pipeline parallelism is used to improve performance by reducing the pipeline + bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks. + The number of virtual blocks per pipeline model parallel rank is the virtual model parallel + size. See Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM: + arxiv.org/pdf/2104.04473.pdf for more details. + """ - Training - -------- + sequence_parallel: bool = False + """Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms + and dropout sequentially. See Reducing Activation Recomputation in Large Transformer Models + (https://arxiv.org/abs/2205.05198) for more details. + """ - fp16 (bool): If true, train with fp16 mixed precision training. Defaults to False. + context_parallel_size: int = 1 + """Splits network input along sequence dimension across GPU ranks.""" + + hierarchical_context_parallel_sizes: Optional[list[int]] = None + """Degrees of the hierarchical context parallelism. Users should provide a list to specify + the sizes for different levels. Taking the a2a+p2p cp comm type as example, it contains + groups of two levels, so the first value of the list indicates the group size of the a2a + communication type, and the second value indicates the group size of the p2p communication + type. + """ - bf16 (bool): If true, train with bf16 mixed precision training. Defaults to False. + expert_model_parallel_size: int = 1 + """Distributes Moe Experts across sub data parallel dimension.""" - params_dtype (torch.dtype): dtype used when intializing the weights. Defaults to torch.float32 + expert_tensor_parallel_size: Optional[int] = None + """Intra-layer tensor model parallelsm for expert layer. Splits tensors across GPU ranks.""" - timers (optional, default=None): TODO + moe_extended_tp: bool = False + """NOTE: Deprecated from MCore v0.10. This flag is ignored. + Its functionality is replaced by expert_tensor_parallel_size. 
+ """ - Optimizations - ------------- + ################### + # Initialization + ################### + perform_initialization: bool = True + """If true, weights are initialized. This option can be useful when you know you are going to + load values from a checkpoint. + """ - gradient_accumulation_fusion (bool): If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA - extension fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install APEX with - --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" - ". Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion. - Defaults to False. + use_cpu_initialization: bool = False + """When set to False, we initialize the weights directly on the GPU. CPU initialization is the + same regardless of tensor model parallelism, but GPU initialization is not. Transferring + weights from CPU to GPU can take a significant amount of time for large models. + """ - async_tensor_model_parallel_allreduce (bool, default=True): If true, enables asynchronous execution of - tensor-model-parallel all-reduce with weight gradient compuation of a column-linear layer. Defaults to False. + ################### + # Training + ################### + fp16: bool = False + """If true, train with fp16 mixed precision training.""" - tp_comm_overlap (bool, default=False): If true, allows overlapping of Linear layer execution with tensor parallel - communication collectives like AllGather/ReduceScatter. Overlapping is done for the linear layers wherever possible - during the forward and the backward pass. Defaults to False. + bf16: bool = False + """If true, train with bf16 mixed precision training.""" - tp_comm_split_ag (bool, default=True): If true, allows All-Gather overlap with Fprop GEMM. Don't care if tp_comm_overlap - is False. + params_dtype: torch.dtype = torch.float32 + """dtype used when intializing the weights.""" - tp_comm_split_rs (bool, default=True): If true, allows Reduce-Scatter overlap with Fprop GEMM. Don't care if - tp_comm_overlap is False. + timers: Optional[Callable] = None + """Timers object to call for various timing functions. See megatron.core.timers.Timers""" - tp_comm_bulk_dgrad (bool, default=True): If true, allows All-Gather overlap with Bprop activation gradient GEMM. Don't - care if tp_comm_overlap is False. + finalize_model_grads_func: Optional[Callable] = None + """Function that finalizes gradients on all workers. Could include ensuring that grads are + all-reduced across data parallelism, pipeline parallelism, and sequence parallelism + dimensions. + """ - tp_comm_bulk_wgrad (bool, default=True): If true, allows Reduce-Scatter overlap with Bprop weight gradient GEMM. Don't - care if tp_comm_overlap is False. + grad_scale_func: Optional[Callable] = None + """If using loss scaling, this function should take the loss and return the scaled loss. If + None, no function is called on the loss. + """ - Parallelism - ----------- + no_sync_func: Optional[Callable] = None + """Function that creates a context that suppresses asynchronous data-parallel communication. If + the model is an instance of core.distributed.DistributedDataParallel, the default is to use + core.distributed.DistributedDataParallel.no_sync. + """ - finalize_model_grads_func (optional): Function that finalizes gradients on all workers. 
Could include ensuring that - grads are all-reduced across data parallelism, pipeline parallelism, and sequence parallelism dimensions. + grad_sync_func: Optional[Callable] = None + """Function that launches asynchronous gradient reductions (e.g. distributed optimizer gradient + reduce-scatters). The function should take one argument: an iterable of parameters whose + gradients are to be synchronized. + """ - Pipeline Parallelism - -------------------- + param_sync_func: Optional[Callable] = None + """Function that launches asynchronous parameter synchronizations (e.g. distributed optimizer + parameter all-gathers). The function should take one argument: an iterable of parameters to + be synchronized. + """ - pipeline_dtype (required): dtype used in p2p communication, usually params_dtype + deterministic_mode: bool = False + """If true, code that has deterministic execution will be chosen. This usually + means slower execution, but is good for debugging and testing. Defaults to False.""" - grad_scale_func (optional, default=None): If using loss scaling, this function should take the loss and return the - scaled loss. If None, no function is called on the loss. + enable_autocast: bool = False + """If true runs the forward step function inside torch.autocast context.""" - enable_autocast (bool): If true runs the forward step function inside torch.autocast context. Default is False. + autocast_dtype: Optional[torch.dtype] = None + """dtype to pass to torch.amp.autocast when enabled. If None, is set to pipeline_dtype.""" - autocast_dtype (torch.dtype): dtype to pass to torch.amp.autocast when enabled. Default is pipeline_dtype. - - variable_seq_lengths (bool, default=False): Support for variable sequence lengths across microbatches. Setting this - communicates the size of tensors during pipeline parallelism communication, because of this extra overhead it - should only be set if the sequence length varies by microbatch within a global batch. + num_microbatches_with_partial_activation_checkpoints: Optional[int] = None + """If int, set the number of microbatches where not all of the layers will be checkpointed and + recomputed. The rest of the microbatches within the window of maximum outstanding + microbatches will recompute all layers (either full recompute or selective recompute). If + None, the checkpoint and recompute will be left up to the forward_step function. - num_microbatches_with_partial_activation_checkpoints (int, default=None): If int, set the number of microbatches - where not all of the layers will be checkpointed and recomputed. The rest of the microbatches within the window - of maximum outstanding microbatches will recompute all layers (either full recompute or selective recompute). If - None, the checkpoint and recompute will be left up to the forward_step function. + """ - overlap_p2p_comm (bool, optional, default=False): When True some of the peer to peer communication for pipeline - parallelism will overlap with computation. Must be False if batch_p2p_comm is true. + ################### + # Optimizations + ################### + gradient_accumulation_fusion: bool = False + """If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA extension + fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install + APEX with --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\" + --global-option=\"--cuda_ext\" ". Note that the extension requires CUDA>=11. 
Otherwise, you + must turn off gradient accumulation fusion. + """ - batch_p2p_comm (bool, default=True): Use batch_isend_irecv instead of individual isend/irecv calls. Must be False - if overlap_p2p_comm is True. + async_tensor_model_parallel_allreduce: bool = False + """NOTE: Deprecated. This flag is ignored.""" - batch_p2p_sync (bool, default=True): When using batch_isend_irecv, do a cuda.device.synchronize afterward to work - around a bug in older version of PyTorch. + use_te_rng_tracker: bool = False + """If true, uses RNG state tracker in TransformerEngine if exists. + """ - use_ring_exchange_p2p (bool, default=False): Use custom ring_exchange kernel instead of - torch.distributed.batch_isend_irecv(). Requires custom built torch with torch.distributed.ring_exchange. + tp_comm_overlap: bool = False + """If true, allows overlapping of Linear layer execution with tensor parallel communication + collectives like AllGather/ReduceScatter. Overlapping is done for the linear layers wherever + possible during the forward and the backward pass. + """ - deallocate_pipeline_outputs (optional, default=False): If True, output data is deallocated after the tensor is sent - to the next pipeline stage. Helps with saving memory, does nothing when pipeline parallel is not used. + tp_comm_bulk_wgrad: bool = True + """If true, allows All-Gather overlap with Bprop activation gradient GEMM. Don't care if + tp_comm_overlap is False. + """ - no_sync_func (optional): Function that creates a context that suppresses asynchronous data-parallel - communication. If the model is an instance of core.distributed.DistributedDataParallel, the default is to use - core.distributed.DistributedDataParallel.no_sync. + tp_comm_bulk_dgrad: bool = True + """If true, allows Reduce-Scatter overlap with Bprop weight gradient GEMM. Don't care if + tp_comm_overlap is False. + """ - grad_sync_func (optional): Function that launches asynchronous gradient reductions (e.g. distributed optimizer - gradient reduce-scatters). The function should take one argument: an iterable of parameters whose gradients are - to be synchronized. + tp_comm_overlap_ag: bool = True + """If true, allows All-Gather overlap with GEMM by pipelining the GEMM and All-Gather. + Don't care if tp_comm_overlap is False. + """ - param_sync_func (optional): Function that launches asynchronous parameter synchronizations (e.g. distributed - optimizer parameter all-gathers). The function should take one argument: an iterable of parameters to be - synchronized. + tp_comm_overlap_rs: bool = True + """If true, allows Reduce-Scatter overlap with GEMM by pipelining the GEMM and Reduce-Scatter. + Don't care if tp_comm_overlap is False. + """ - pipeline_model_parallel_split_rank (int, default=None): If int, rank where encoder and decoder should be split in - cases where the model has both an encoder and decoder (e.g., T5). Ignored if None. + tp_comm_overlap_rs_dgrad: bool = False + """If true, allows Reduce-Scatter overlap with DGRAD GEMM by pipelining the + GEMM and Reduce-Scatter splits. Don't care if tp_comm_overlap is False. + """ - barrier_with_L1_time (bool, default=True): If true, use barrier with level 1 time measurements. It is up to the user - to make sure calling barrier with their timers will not result in hangs. This can happen if for example the user - adds a level 1 timer that is not called by all ranks. + tp_comm_split_ag: bool = True + """Deprecated from TransformerEngine v1.6.0. 
+ If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather + splits. Don't care if tp_comm_overlap is False. + """ + tp_comm_atomic_ag: bool = False + """Deprecated from TransformerEngine v1.6.0. + If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather + both done atomically. Don't care if tp_comm_overlap is False. """ - # Model parallelism - tensor_model_parallel_size: int = 1 - context_parallel_size: int = 1 - pipeline_model_parallel_size: int = 1 - virtual_pipeline_model_parallel_size: Optional[int] = None - sequence_parallel: bool = False - expert_model_parallel_size: int = 1 + tp_comm_split_rs: bool = True + """Deprecated from TransformerEngine v1.6.0. + If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and + Reduce-Scatter splits. Don't care if tp_comm_overlap is False. + """ - # Initialization - perform_initialization: bool = True - use_cpu_initialization: bool = False + tp_comm_atomic_rs: bool = False + """Deprecated from TransformerEngine v1.6.0. + If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and + Reduce-Scatter both done atomically. Don't care if tp_comm_overlap is False. + """ - # Training - fp16: bool = False - bf16: bool = False - params_dtype: torch.dtype = torch.float32 - timers: Callable = None + cross_entropy_loss_fusion: bool = False + """If this is enabled, the fused cross entropy implementation would be used. + Defaults to False. + """ - # Optimizations - gradient_accumulation_fusion: bool = False - async_tensor_model_parallel_allreduce: bool = False - tp_comm_overlap: bool = False + tp_comm_overlap_disable_qkv: bool = False + """ + If true, the AllGather -> Gemm overlap for QKV gets disabled + """ - # Debug Options - tp_comm_split_ag: bool = True - tp_comm_split_rs: bool = True - tp_comm_bulk_wgrad: bool = True - tp_comm_bulk_dgrad: bool = True + tp_comm_overlap_disable_fc1: bool = False + """ + If true, the AllGather -> Gemm overlap for FC1 layer of MLP gets disabled + """ - # Parallelism - finalize_model_grads_func: Callable = None + tp_comm_bootstrap_backend: str = 'nccl' + """ + Set the bootstrapping backend out of 'nccl', 'mpi', and 'gloo' + """ + ################### # Pipeline Parallel + ################### pipeline_dtype: torch.dtype = None - grad_scale_func: Callable = None - enable_autocast: bool = False - autocast_dtype: torch.dtype = None + """dtype used in p2p communication, usually params_dtype""" + variable_seq_lengths: bool = False - num_microbatches_with_partial_activation_checkpoints: Optional[int] = None + """Support for variable sequence lengths across microbatches. Setting this communicates the size + of tensors during pipeline parallelism communication, because of this extra overhead it + should only be set if the sequence length varies by microbatch within a global batch. + """ + overlap_p2p_comm: bool = False + """When True some of the peer to peer communication for pipeline parallelism will overlap with + computation. Must be False if batch_p2p_comm is true. + """ + batch_p2p_comm: bool = True + """Use batch_isend_irecv instead of individual isend/irecv calls. Must be False if + overlap_p2p_comm is True. + """ + batch_p2p_sync: bool = True + """When using batch_isend_irecv, do a cuda.device.synchronize afterward to work around a bug in + older version of PyTorch. + """ + use_ring_exchange_p2p: bool = False + """Use custom ring_exchange kernel instead of torch.distributed.batch_isend_irecv(). 
Requires + custom built torch with torch.distributed.ring_exchange. + """ + deallocate_pipeline_outputs: bool = False - no_sync_func: Callable = None - grad_sync_func: Callable = None - param_sync_func: Callable = None + """If True, output data is deallocated after the tensor is sent to the next pipeline stage. + Helps with saving memory, does nothing when pipeline parallel is not used. + """ + + defer_embedding_wgrad_compute: bool = False + """If true, defers the embedding WGRAD GEMMs while pipeline flush is + taking place enabling us to hide pipeline flush latency. Defaults to False. + """ + + wgrad_deferral_limit: int = 0 + """This value tunes the number of micro-batches for which the embedding weight gradient compute + needs to be deferred to pipeline flush, this argument is invalid if + `defer_embedding_wgrad_compute` is False. + Defaults to 0, which means all micro-batches are deferred. + """ + pipeline_model_parallel_split_rank: Optional[int] = None + """If int, rank where encoder and decoder should be split in cases where the model has both an + encoder and decoder (e.g., T5). Ignored if None. + """ + + overlap_p2p_comm_warmup_flush: bool = False + """If true, overlap communication and computation in warm up and flush phase. + Only valid when overlap_p2p_comm is True and batch_p2p_comm is False. + Defaults to False. + """ + + microbatch_group_size_per_vp_stage: Optional[int] = None + """This value specifies the number of micro-batches that are executed + at a time for a given virtual stage (both forward and backward). + Default (in __post_init__() method below) to pipeline_parallel_size + which specifies a depth-first schedule. + Example: for PP=2 VP=2, when microbatch_group_size_per_vp_stage=2, + num_microbatches = 4, we have + rank 0 | 0 1 0 1 2 3 2 3 + rank 1 | 0 1 0 1 2 3 2 3 + When microbatch_group_size_per_vp_stage=3, num_microbatches = 5, + we have + rank 0 | 0 1 2 0 1 2 3 4 3 4 + rank 1 | 0 1 2 0 1 2 3 4 3 4 + """ + + ################### + # CPU Offloading + ################### + cpu_offloading: bool = False + """When set to True, all the activations are offloaded to the CPU asynchronously.""" + + cpu_offloading_num_layers: int = 0 + """Tells the number of transformer layers for which activations has to be offloaded.""" + + _cpu_offloading_context: Optional[ContextManager] = ( + None + # Used for internal use only, not to be set by a user. + # TODO: Need to move to the 'right' place when possible. + ) + """For internal use only, do not set.""" + cpu_offloading_activations: bool = True + """If True, offloads the activations to CPU.""" + + cpu_offloading_weights: bool = True + """If True, offloads the weights to CPU.""" + + ################### # Timing + ################### barrier_with_L1_time: bool = True + """If true, use barrier with level 1 time measurements. It is up to the user to make sure + calling barrier with their timers will not result in hangs. This can happen if for example + the user adds a level 1 timer that is not called by all ranks. + """ def __post_init__(self): - """ Python dataclass method that is used to modify attributes after initialization. - See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details. + """Python dataclass method that is used to modify attributes after initialization. + See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more + details. 
""" if self.sequence_parallel: if self.tensor_model_parallel_size <= 1: raise ValueError("Can not use sequence paralllelism without tensor parallelism") - if self.async_tensor_model_parallel_allreduce: - # sequence_parallelism already does this async - self.async_tensor_model_parallel_allreduce = False + + if self.expert_tensor_parallel_size is None: + self.expert_tensor_parallel_size = self.tensor_model_parallel_size if self.pipeline_model_parallel_size > 1: if self.pipeline_dtype is None: @@ -215,8 +359,34 @@ def __post_init__(self): if self.autocast_dtype is None: self.autocast_dtype = self.params_dtype + if self.defer_embedding_wgrad_compute and self.pipeline_model_parallel_size == 1: + raise ValueError( + "Cannot defer embedding wgrad compute when pipeline model parallel is not used" + ) + + if self.defer_embedding_wgrad_compute and not self.gradient_accumulation_fusion: + raise ValueError( + "Cannot defer embedding wgrad compute when gradient accumulation fusion is not used" + ) + + if self.defer_embedding_wgrad_compute and self.wgrad_deferral_limit < 0: + raise ValueError( + "Wgrad deferral limit should be greater than or equal to 0 when it is enabled!" + ) + if self.expert_model_parallel_size > 1 and self.tensor_model_parallel_size > 1: if self.sequence_parallel is False: raise ValueError( - "When using expert parallelism and tensor parallelism, sequence parallelism must be used" + "When using expert parallelism and tensor parallelism, " + "sequence parallelism must be used" + ) + + if self.microbatch_group_size_per_vp_stage is None: + self.microbatch_group_size_per_vp_stage = self.pipeline_model_parallel_size + + if self.overlap_p2p_comm_warmup_flush: + if not self.overlap_p2p_comm or self.batch_p2p_comm: + raise ValueError( + "Pipeline parallel communication overlapping in warmup and flush is only " + "compatible with overlap_p2p_comm but not batch_p2p_comm." ) diff --git a/megatron/core/models/T5/__init__.py b/megatron/core/models/T5/__init__.py new file mode 100644 index 0000000000..2551f81e65 --- /dev/null +++ b/megatron/core/models/T5/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from .t5_model import T5Model diff --git a/megatron/core/models/T5/t5_model.py b/megatron/core/models/T5/t5_model.py new file mode 100644 index 0000000000..68335591df --- /dev/null +++ b/megatron/core/models/T5/t5_model.py @@ -0,0 +1,517 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
+ +from typing import List, Literal, Optional, Tuple + +import torch +from torch import Tensor + +from megatron.core import InferenceParams, parallel_state, tensor_parallel +from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.enums import ModelType +from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.models.common.embeddings.relative_pos_embedding import RelativePositionEmbedding +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from megatron.core.models.common.language_module.language_module import LanguageModule +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_config import TransformerConfig + + +class T5LMHead(MegatronModule): + """Masked LM head for T5 + + Args: + config (TransformerConfig): transformer config + parallel_output (bool): wether output logits being distributed or not. + vocab_size (int): vocabulary size + pre_process (bool): Include embedding layer + share_embeddings_and_output_weights (bool): When True, input + embeddings and output logit weights are shared. + """ + + def __init__( + self, + config: TransformerConfig, + parallel_output: bool, + vocab_size: int, + pre_process: bool = True, + share_embeddings_and_output_weights: bool = False, + ): + super(T5LMHead, self).__init__(config=config) + + if has_config_logger_enabled(config): + log_config_to_disk(config, locals(), prefix=type(self).__name__) + + self.parallel_output = parallel_output + + self.output_layer = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + vocab_size, + config=config, + init_method=config.init_method, + bias=share_embeddings_and_output_weights, + skip_bias_add=not share_embeddings_and_output_weights, + gather_output=not self.parallel_output, + skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights, + ) + + def forward(self, hidden_states: Tensor, word_embeddings_weight: Tensor) -> Tensor: + """Forward pass. + + Args: + hidden_states (Tensor): output hidden states from decoder + word_embeddings_weight (Tensor): word embedding weight + + Returns: + Tensor: logits tensor + """ + + logits, _ = self.output_layer(hidden_states, weight=word_embeddings_weight) + return logits + + +class T5Model(LanguageModule): + """T5 Language model. + + Args: + config (TransformerConfig): transformer config + + encoder_config (TransformerConfig): encoder transformer config + + transformer_encoder_layer_spec (ModuleSpec): transformer layer + customization specs for encoder + + transformer_decoder_layer_spec (ModuleSpec): transformer layer + customization specs for decoder + + vocab_size (int): vocabulary size + + max_sequence_length (int): maximum size of sequence. 
This is used for positional embedding + + pre_process (bool): Include embedding layer (used with pipeline parallelism) + + post_process (bool): Include an output layer (used with pipeline parallelism) + + fp16_lm_cross_entropy (bool, optional): Defaults to False + + parallel_output (bool): Do not gather the outputs, + keep them split across tensor parallel ranks + + share_embeddings_and_output_weights (bool): When True, + input embeddings and output logit weights are shared. Defaults to False. + + position_embedding_type (string): Position embedding type. + Options ['learned_absolute', 'rope']. + Defaults is 'learned_absolute'. + + rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings. + Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'. + + seq_len_interpolation_factor (float): scale of linearly interpolating + RoPE for longer sequences. The value must be a float larger than 1.0. + Defaults to None. + + add_encoder (bool): Create the encoder (used with pipeline parallelism). + When using pipelining, the encoder will only be created on a subset + of the pipeline ranks. + + add_decoder (bool): Include an output layer (used with pipeline parallelism). + As with `add_encoder`, when using this model and pipelining, + the decoder will only be created on a subset of the pipeline ranks. + """ + + def __init__( + self, + config: TransformerConfig, + encoder_config: TransformerConfig, + transformer_encoder_layer_spec: ModuleSpec, + transformer_decoder_layer_spec: ModuleSpec, + vocab_size: int, + max_sequence_length: int, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + share_embeddings_and_output_weights: bool = False, + position_embedding_type: Literal[ + 'learned_absolute', 'rope', 'relative' + ] = 'learned_absolute', + rotary_percent: float = 1.0, + seq_len_interpolation_factor: Optional[float] = None, + relative_attention_num_buckets: int = 32, + relative_attention_max_distance: int = 128, + add_encoder: bool = True, + add_decoder: bool = True, + ): + + super(T5Model, self).__init__(config=config) + + self.config: TransformerConfig = config + self.encoder_config: TransformerConfig = encoder_config + self.transformer_encoder_layer_spec: ModuleSpec = transformer_encoder_layer_spec + self.transformer_decoder_layer_spec: ModuleSpec = transformer_decoder_layer_spec + self.vocab_size = vocab_size + self.max_sequence_length = max_sequence_length + self.pre_process = pre_process + self.post_process = post_process + self.add_encoder = add_encoder + self.add_decoder = add_decoder + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + self.parallel_output = parallel_output + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.position_embedding_type = position_embedding_type + self.encoder_hidden_state = None + + self.model_type = ModelType.encoder_and_decoder + + # Tells schedules.py that this model has a skip connection + # between the encoder's output and the decoder + # (and hence both the encoder and decoder's tensors are required for correct backprop). 
+ self.xattn_needed = True + + # specify the position embeddings as a member + # variable in the T5 class so that they are easy to + # find for `finalize_model_grads._allreduce_position_embedding_grads` + self.position_embeddings = None + if self.pre_process: + self.embedding = LanguageModelEmbedding( + config=self.config, + vocab_size=self.vocab_size, + max_sequence_length=self.max_sequence_length, + position_embedding_type=self.position_embedding_type, + ) + if position_embedding_type == "learned_absolute": + self.position_embeddings = self.embedding.position_embeddings + else: + self.position_embeddings = None + + # Rotary Position Embeddings + if self.position_embedding_type == 'rope': + self.rotary_pos_emb = RotaryEmbedding( + kv_channels=self.config.kv_channels, + rotary_percent=rotary_percent, + rotary_interleaved=self.config.rotary_interleaved, + seq_len_interpolation_factor=seq_len_interpolation_factor, + use_cpu_initialization=self.config.use_cpu_initialization, + ) + + # Relative Position Embeddings + if self.position_embedding_type == 'relative': + self.encoder_relative_pos_emb = RelativePositionEmbedding( + bidirectional=True, + init_method=self.config.init_method, + num_attention_heads=self.config.num_attention_heads, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + ) + self.decoder_relative_pos_emb = RelativePositionEmbedding( + bidirectional=False, + init_method=self.config.init_method, + num_attention_heads=self.config.num_attention_heads, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + ) + + # Transformer encoder + encoder_spec, decoder_spec = ( + self.transformer_encoder_layer_spec, + self.transformer_decoder_layer_spec, + ) + if self.add_encoder: + self.encoder = TransformerBlock( + config=self.encoder_config, + spec=encoder_spec, + pre_process=self.pre_process, + post_process=self.post_process, + ) + else: + self.encoder = None + + if self.add_decoder: + # Transformer decoder + self.decoder = TransformerBlock( + config=self.config, + spec=decoder_spec, + pre_process=self.pre_process, + post_process=self.post_process, + ) + else: + self.decoder = None + + # Output + if post_process: + self.lm_head = T5LMHead( + config, + parallel_output, + self.vocab_size, + self.pre_process, + self.share_embeddings_and_output_weights, + ) + self.output_layer = self.lm_head.output_layer + + if self.pre_process or self.post_process: + self.setup_embeddings_and_output_layer() + + def forward( + self, + encoder_input_ids: Tensor, + decoder_input_ids: Tensor, + encoder_attn_mask: Tensor, + decoder_attn_mask: Tensor, + encoder_decoder_attn_mask: Tensor, + lm_labels: Tensor = None, + encoder_hidden_states: Tensor = None, + output_encoder_hidden_only: bool = False, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + ) -> Tensor: + """Forward pass. 
+ + Args: + encoder_input_ids (Tensor): input ids for encoder + decoder_input_ids (Tensor): input ids for decoder + encoder_attn_mask (Tensor): self-attention mask for encoder + decoder_attn_mask (Tensor): self-attention mask for decoder + encoder_decoder_attn_mask (Tensor): cross-attention mask between encoder and decoder + lm_labels (Tensor): labels for decoder output + inference_params (InferenceParams): relevant arguments for inferencing + + Returns: + Tensor: loss tensor + """ + + ## Encoder forward + if encoder_hidden_states is None: + + # Encoder position ids + encoder_position_ids = t5_position_ids(encoder_input_ids) + + # Encoder embedding. + if self.pre_process: + encoder_input = self.embedding( + input_ids=encoder_input_ids, position_ids=encoder_position_ids + ) + else: + # intermediate stage of pipeline + encoder_input = None + + # Rotary positional embeddings + rotary_pos_emb = None + if self.position_embedding_type == 'rope': + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_params, self.encoder, encoder_input, self.config, packed_seq_params + ) + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + + # Relative positional embeddings + encoder_attention_bias_parallel = None + if self.position_embedding_type == 'relative': + query_seq_length = RelativePositionEmbedding.get_relative_seq_len( + inference_params, self.encoder, encoder_input, self.config + ) + key_seq_length = query_seq_length + attention_bias = self.encoder_relative_pos_emb(query_seq_length, key_seq_length) + + # Scatter attention_bias to TP ranks + # First, reshape [1, num_head, seqlen_q, seqlen_kv] to + # [1, seqlen_q, seqlen_kv, num_head] to be scatter along + # the last (num_heads dimension) + attention_bias = torch.permute(attention_bias, (0, 2, 3, 1)) + # Then, scatter to TP region + attention_bias_parallel = scatter_to_tensor_model_parallel_region(attention_bias) + # Lastly, revert the dimension back to [1, num_head, seqlen_q, seqlen_kv] + encoder_attention_bias_parallel = torch.permute( + attention_bias_parallel, (0, 3, 1, 2) + ) + + # Run encoder. + if self.add_encoder: + encoder_hidden_states = self.encoder( + hidden_states=encoder_input, + attention_mask=encoder_attn_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + attention_bias=encoder_attention_bias_parallel, + ) + else: + encoder_hidden_states = self.encoder_hidden_state + + if not self.add_decoder or output_encoder_hidden_only: + return encoder_hidden_states + + ## Decoder forward + # Decoder position ids + decoder_position_ids = t5_position_ids(decoder_input_ids) + + # Decoder embedding. 
+ if self.pre_process: + decoder_input = self.embedding( + input_ids=decoder_input_ids, position_ids=decoder_position_ids + ) + else: + # intermediate stage of pipeline + decoder_input = None ### should it take encoder_hidden_states + + # Rotary positional embeddings + rotary_pos_emb = None + if self.position_embedding_type == 'rope': + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_params, self.decoder, decoder_input, self.config, packed_seq_params + ) + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + + # Relative positional embeddings + decoder_attention_bias_parallel = None + if self.position_embedding_type == 'relative': + query_seq_length = RelativePositionEmbedding.get_relative_seq_len( + inference_params, self.decoder, decoder_input, self.config + ) + key_seq_length = query_seq_length + attention_bias = self.decoder_relative_pos_emb(query_seq_length, key_seq_length) + + # Scatter attention_bias to TP ranks + # First, reshape [1, num_head, seqlen_q, seqlen_kv] to + # [1, seqlen_q, seqlen_kv, num_head] to be scatter along + # the last (num_heads dimension) + attention_bias = torch.permute(attention_bias, (0, 2, 3, 1)) + # Then, scatter to TP region + attention_bias_parallel = scatter_to_tensor_model_parallel_region(attention_bias) + # Lastly, revert the dimension back to [1, num_head, seqlen_q, seqlen_kv] + decoder_attention_bias_parallel = torch.permute(attention_bias_parallel, (0, 3, 1, 2)) + + # Run decoder. + decoder_hidden_states = self.decoder( + hidden_states=decoder_input, + attention_mask=decoder_attn_mask, + context=encoder_hidden_states, + context_mask=encoder_decoder_attn_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + attention_bias=decoder_attention_bias_parallel, + ) + + if self.post_process: + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + lm_logits = self.lm_head(decoder_hidden_states, word_embeddings_weight=output_weight) + + if lm_labels is None: + # [s b h] => [b s h] + return lm_logits.transpose(0, 1).contiguous() + else: + # [b s] => [s b] + lm_loss = self.compute_language_model_loss(lm_labels, lm_logits) + return lm_loss + else: + return decoder_hidden_states + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + + if self.add_encoder and self.add_decoder: + assert ( + len(input_tensor) == 1 + ), 'input_tensor should only be length 1 for stage with both encoder and decoder' + self.encoder.set_input_tensor(input_tensor[0]) + elif self.add_encoder: + assert ( + len(input_tensor) == 1 + ), 'input_tensor should only be length 1 for stage with only encoder' + self.encoder.set_input_tensor(input_tensor[0]) + elif self.add_decoder: + if len(input_tensor) == 2: + self.decoder.set_input_tensor(input_tensor[0]) + self.encoder_hidden_state = input_tensor[1] + elif len(input_tensor) == 1: + self.decoder.set_input_tensor(None) + self.encoder_hidden_state = input_tensor[0] + else: + raise Exception('input_tensor must have either length 1 or 2') + else: + raise Exception('Stage must have at least either encoder or decoder') + + def shared_embedding_or_output_weight(self) -> Tensor: + """Function to share the input embeddings and output logit weights.""" + + if self.pre_process: + return 
self.embedding.word_embeddings.weight + elif self.post_process: + return self.lm_head.output_layer.weight + return None + + def sharded_state_dict( + self, + prefix: str = '', + sharded_offsets: Tuple[Tuple[int, int, int]] = (), + metadata: Optional[dict] = None, + ) -> ShardedStateDict: + """Sharded state dict implementation handling duplication of encoder and decoder layers. + + Some layers (output, embedding) are shared between the encoder and decoder. + This method sets the replica_id for them to ensure there is only one + layer instance with replica_id (0, 0, 0). + + Args: + prefix (str): Module name prefix. + sharded_offsets (tuple): PP related offsets, expected to be empty at this module level. + metadata (Optional[Dict]): metadata controlling sharded state dict creation. + + Returns: + ShardedStateDict: sharded state dict for the T5Model + """ + sharded_sd = super().sharded_state_dict(prefix, sharded_offsets, metadata) + if not parallel_state.is_inside_encoder(): + for k, sh_ten in sharded_sd.items(): + if not k.startswith(f'{prefix}decoder'): + # Bump replica_id of all the layers shared with the encoder (output, embedding) + sh_ten.replica_id = (sh_ten.replica_id[0] + 1, *sh_ten.replica_id[1:]) + return sharded_sd + + +def t5_extended_attention_mask(attention_mask_list: List[Tensor]) -> List[Tensor]: + """Creates the extended attention mask + + Converts the attention mask of dimension [batch size, seq_len, seq_len] + to [batch size, 1, seq_len, seq_len] + + Args: + attention_mask (Tensor): The input attention mask + + Returns: + Tensor: The extended binary attention mask + """ + + def attn_mask_postprocess(attn_mask): + # [b, 1, s, s] + extended_attention_mask = attn_mask.unsqueeze(1) + return extended_attention_mask + + return [ + (attn_mask_postprocess(attn_mask) if attn_mask is not None else None) + for attn_mask in attention_mask_list + ] + + +def t5_position_ids(token_ids: Tensor) -> Tensor: + """Calculate position ids from token ids + Args: + token_ids (Tensor): input tokens + + Returns: + Tensor: position ids + """ + seq_length = token_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(token_ids) + + return position_ids diff --git a/megatron/core/models/T5/t5_spec.py b/megatron/core/models/T5/t5_spec.py new file mode 100644 index 0000000000..8370b07df1 --- /dev/null +++ b/megatron/core/models/T5/t5_spec.py @@ -0,0 +1,248 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
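A small shape sketch (not part of the diff) for the two module-level helpers defined at the end of t5_model.py above; the tensors are illustrative.

import torch

from megatron.core.models.T5.t5_model import t5_extended_attention_mask, t5_position_ids

tokens = torch.randint(0, 1000, (2, 8))   # [batch, seq_len]
position_ids = t5_position_ids(tokens)    # [2, 8]; each row is 0..7

# [b, s, s] masks gain a broadcastable head dimension -> [b, 1, s, s]; None entries pass through.
enc_mask = torch.ones(2, 8, 8, dtype=torch.bool)
(extended,) = t5_extended_attention_mask([enc_mask])
assert extended.shape == (2, 1, 8, 8)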
+from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.attention import ( + CrossAttention, + CrossAttentionSubmodules, + SelfAttention, + SelfAttentionSubmodules, +) +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_block import TransformerBlockSubmodules +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TENorm, + TERowParallelLinear, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +try: + import apex # pylint: disable=unused-import + + from megatron.core.fusions.fused_layer_norm import FusedLayerNorm + + HAVE_APEX = True + LNImpl = FusedLayerNorm +except ImportError: + import warnings + + from megatron.core.transformer.torch_norm import WrappedTorchNorm + + warnings.warn(f'Apex is not installed. Falling back to Torch Norm') + LNImpl = WrappedTorchNorm + + +def encoder_model_with_transformer_engine_default_spec() -> ModuleSpec: + """T5 encoder TE spec (uses Transformer Engine components).""" + + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.padding}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + k_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ) + + +def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec: + """T5 decoder TE spec (uses Transformer Engine components).""" + + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + k_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_cross_attn_layernorm=TENorm, + cross_attention=ModuleSpec( + module=CrossAttention, + params={"attn_mask_type": AttnMaskType.padding}, + submodules=CrossAttentionSubmodules( + linear_q=TEColumnParallelLinear, + linear_kv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + ), + ), + cross_attn_bda=get_bias_dropout_add, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ) + + +def encoder_model_with_local_spec() -> ModuleSpec: + """T5 encoder local spec (uses Megatron-Core components).""" + + return ModuleSpec( + 
module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=LNImpl, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.arbitrary}, + submodules=SelfAttentionSubmodules( + linear_qkv=ColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=RowParallelLinear, + q_layernorm=IdentityOp, + k_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=LNImpl, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + sharded_state_dict_keys_map={ + 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', + 'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_', + }, + ), + ) + + +def decoder_model_with_local_spec() -> ModuleSpec: + """T5 decoder local spec (uses Megatron-Core components).""" + + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=LNImpl, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=ColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=RowParallelLinear, + q_layernorm=IdentityOp, + k_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_cross_attn_layernorm=LNImpl, + cross_attention=ModuleSpec( + module=CrossAttention, + params={"attn_mask_type": AttnMaskType.arbitrary}, + submodules=CrossAttentionSubmodules( + linear_q=ColumnParallelLinear, + linear_kv=ColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=RowParallelLinear, + ), + ), + cross_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=LNImpl, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + sharded_state_dict_keys_map={ + 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', + 'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_', + }, + ), + ) + + +def get_t5_encoder_with_transformer_engine_block_spec( + num_layers: int, +) -> TransformerBlockSubmodules: + """T5 encoder block spec for Transformer Engine + + Args: + config (TransformerConfig): config, containing number of layers for encoder + """ + + layer_spec = encoder_model_with_transformer_engine_default_spec() + block_spec = TransformerBlockSubmodules([layer_spec] * num_layers, layer_norm=TENorm) + return block_spec + + +def get_t5_decoder_with_transformer_engine_block_spec( + num_layers: int, +) -> TransformerBlockSubmodules: + """T5 decoder block spec for Transformer Engine + + Args: + config (TransformerConfig): config, containing number of layers for decoder + """ + + layer_spec = decoder_model_with_transformer_engine_default_spec() + block_spec = TransformerBlockSubmodules([layer_spec] * num_layers, layer_norm=TENorm) + return block_spec + + +def get_t5_encoder_with_local_block_spec(num_layers: int) -> TransformerBlockSubmodules: + """T5 encoder block spec for local (uses Megatron-Core components) + + Args: + num_layers (int): number of encoder layers + """ + + layer_spec = encoder_model_with_local_spec() + block_spec = TransformerBlockSubmodules([layer_spec] * num_layers, layer_norm=TENorm) + return block_spec + + +def get_t5_decoder_with_local_block_spec(num_layers: int) -> TransformerBlockSubmodules: + """T5 decoder block spec for local (uses Megatron-Core components) + + Args: + num_layers (int): 
number of decoder layers + """ + + layer_spec = decoder_model_with_local_spec() + block_spec = TransformerBlockSubmodules([layer_spec] * num_layers, layer_norm=TENorm) + return block_spec diff --git a/megatron/fused_kernels/tests/__init__.py b/megatron/core/models/bert/__init__.py similarity index 100% rename from megatron/fused_kernels/tests/__init__.py rename to megatron/core/models/bert/__init__.py diff --git a/megatron/core/models/bert/bert_layer_specs.py b/megatron/core/models/bert/bert_layer_specs.py new file mode 100644 index 0000000000..4edc2ed628 --- /dev/null +++ b/megatron/core/models/bert/bert_layer_specs.py @@ -0,0 +1,116 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import warnings + +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + +try: + from megatron.core.extensions.transformer_engine import ( + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TERowParallelLinear, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +try: + import apex # pylint: disable=unused-import + + from megatron.core.fusions.fused_layer_norm import FusedLayerNorm + + HAVE_APEX = True + LNImpl = FusedLayerNorm +except ImportError: + + from megatron.core.transformer.torch_norm import WrappedTorchNorm + + warnings.warn('Apex is not installed. Falling back to Torch Norm') + LNImpl = WrappedTorchNorm + + +def get_bert_layer_with_transformer_engine_spec(): + """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). + + Returns: + ModuleSpec: Module specification with TE modules + """ + if not HAVE_TE: + raise ImportError( + "Transformer Engine is not installed. Please use local Bert layer spec instead." + ) + + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.padding}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + k_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ) + + +def __getattr__(name): + if name == 'bert_layer_with_transformer_engine_spec': + warnings.warn( + """Attribute bert_layer_specs.bert_layer_with_transformer_engine_spec is on a + deprecation track and will be removed in future releases. 
Please migrate to + bert_layer_specs.get_bert_layer_with_transformer_engine_spec().""" + ) + + return get_bert_layer_with_transformer_engine_spec() + + +# Use this spec for an implementation using only modules in megatron core +bert_layer_local_spec = ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=LNImpl, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.padding}, + submodules=SelfAttentionSubmodules( + linear_qkv=ColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=RowParallelLinear, + q_layernorm=IdentityOp, + k_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=LNImpl, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear), + ), + mlp_bda=get_bias_dropout_add, + sharded_state_dict_keys_map={ + 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', + 'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_', + }, + ), +) diff --git a/megatron/core/models/bert/bert_lm_head.py b/megatron/core/models/bert/bert_lm_head.py new file mode 100644 index 0000000000..9002eab978 --- /dev/null +++ b/megatron/core/models/bert/bert_lm_head.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import torch +from torch import Tensor + +from megatron.core.fusions.fused_layer_norm import HAVE_FUSED_LAYER_NORM, FusedLayerNorm +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.utils import get_linear_layer + +if HAVE_FUSED_LAYER_NORM: + LNImpl = FusedLayerNorm +else: + import warnings + + warnings.warn(f'Apex is not installed. Falling back to Torch Norm') + from megatron.core.transformer.torch_norm import WrappedTorchNorm as LNImpl + + +class BertLMHead(MegatronModule): + """Masked LM head for Bert. + + Args: + hidden_size: hidden size + config (TransformerConfig): TransformerConfig object + """ + + def __init__(self, hidden_size: int, config: TransformerConfig): + super().__init__(config=config) + + # TODO: Should switch this to TE ? + self.dense = get_linear_layer( + hidden_size, hidden_size, config.init_method, config.perform_initialization + ) + + setattr(self.dense.weight, 'sequence_parallel', config.sequence_parallel) + setattr(self.dense.bias, 'sequence_parallel', config.sequence_parallel) + + self.layer_norm = LNImpl( + config=config, hidden_size=hidden_size, eps=config.layernorm_epsilon + ) + + self.gelu = torch.nn.functional.gelu + + def forward(self, hidden_states: Tensor) -> Tensor: + """forward pass""" + + hidden_states = self.dense(hidden_states) + hidden_states = self.gelu(hidden_states) + hidden_states = self.layer_norm(hidden_states) + return hidden_states diff --git a/megatron/core/models/bert/bert_model.py b/megatron/core/models/bert/bert_model.py new file mode 100644 index 0000000000..1c3684c04b --- /dev/null +++ b/megatron/core/models/bert/bert_model.py @@ -0,0 +1,373 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
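A selection sketch (not part of the diff) for the BERT layer specs defined above: a caller can try the Transformer Engine spec first and fall back to the purely local spec, since get_bert_layer_with_transformer_engine_spec() raises ImportError when TE is not installed. The deprecated module attribute handled by __getattr__ still resolves, but emits a deprecation warning.

from megatron.core.models.bert import bert_layer_specs

try:
    # Lower-level Transformer Engine modules (required for fp8 training).
    layer_spec = bert_layer_specs.get_bert_layer_with_transformer_engine_spec()
except ImportError:
    # Megatron-Core-only implementation.
    layer_spec = bert_layer_specs.bert_layer_local_spec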
+ +import warnings +from typing import Literal, Optional + +import torch +from torch import Tensor + +from megatron.core import parallel_state, tensor_parallel +from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk +from megatron.core.models.bert.bert_lm_head import BertLMHead +from megatron.core.models.bert.pooler import Pooler +from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from megatron.core.models.common.language_module.language_module import LanguageModule +from megatron.core.transformer.dot_product_attention import ( + DotProductAttention as MCoreDotProductAttention, +) +from megatron.core.transformer.enums import AttnBackend, AttnMaskType, ModelType +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.utils import get_linear_layer +from megatron.core.utils import get_te_version as _get_te_version +from megatron.core.utils import is_te_min_version + + +def get_te_version(): + """Included for backwards compatibility.""" + warnings.warn("`get_te_version` will be deprecated in a future release") + return _get_te_version() + + +class BertModel(LanguageModule): + """Transformer language model. + + Args: + config (TransformerConfig): transformer config + num_tokentypes (int) : Set to 2 when args.bert_binary_head is True, and 0 otherwise. + Defaults to 0. + transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers + vocab_size (int): vocabulary size + max_sequence_length (int): maximum size of sequence. This is used for positional embedding + pre_process (bool): Include embedding layer (used with pipeline parallelism) + post_process (bool): Include an output layer (used with pipeline parallelism) + parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel + ranks + share_embeddings_and_output_weights (bool): When True, input embeddings and output logit + weights are shared. Defaults to False. + position_embedding_type (string): Position embedding type. + Options ['learned_absolute', 'rope']. Defaults is 'learned_absolute'. + rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings. + Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'. 
+ """ + + def __init__( + self, + config: TransformerConfig, + num_tokentypes: int, + transformer_layer_spec: ModuleSpec, + vocab_size: int, + max_sequence_length: int, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + share_embeddings_and_output_weights: bool = False, + position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute', + rotary_percent: float = 1.0, + seq_len_interpolation_factor: Optional[float] = None, + add_binary_head=True, + return_embeddings=False, + ): + super(BertModel, self).__init__(config=config) + + if has_config_logger_enabled(config): + log_config_to_disk(config, locals(), prefix=type(self).__name__) + + if return_embeddings: + assert self.post_process and self.add_binary_head + + self.config: TransformerConfig = config + self.transformer_layer_spec: ModuleSpec = transformer_layer_spec + self.vocab_size = vocab_size + self.max_sequence_length = max_sequence_length + self.pre_process = pre_process + self.post_process = post_process + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + self.parallel_output = parallel_output + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.position_embedding_type = position_embedding_type + self.add_binary_head = add_binary_head + self.return_embeddings = return_embeddings + + # megatron core pipelining currently depends on model type + self.model_type = ModelType.encoder_or_decoder + + self.attn_mask_dimensions = self._sanity_check_attention_and_get_attn_mask_dimension() + + # Embeddings. + if self.pre_process: + self.embedding = LanguageModelEmbedding( + config=self.config, + vocab_size=self.vocab_size, + max_sequence_length=self.max_sequence_length, + position_embedding_type=position_embedding_type, + num_tokentypes=num_tokentypes, + ) + + if self.position_embedding_type == 'rope': + self.rotary_pos_emb = RotaryEmbedding( + kv_channels=self.config.kv_channels, + rotary_percent=rotary_percent, + rotary_interleaved=self.config.rotary_interleaved, + seq_len_interpolation_factor=seq_len_interpolation_factor, + use_cpu_initialization=self.config.use_cpu_initialization, + ) + + # Transformer. + self.encoder = TransformerBlock( + config=self.config, + spec=self.transformer_layer_spec, + pre_process=self.pre_process, + post_process=self.post_process, + ) + + # Output + if post_process: + # TODO: Make sure you are passing in the mpu_vocab_size properly + self.lm_head = BertLMHead(config.hidden_size, config) + + self.output_layer = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + self.vocab_size, + config=config, + init_method=config.init_method, + bias=True, + skip_bias_add=False, + gather_output=not self.parallel_output, + skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights, + ) + + self.binary_head = None + if self.add_binary_head: + # TODO: Shoudl switch this to TE ? + self.binary_head = get_linear_layer( + config.hidden_size, 2, config.init_method, config.perform_initialization + ) + + self.pooler = Pooler( + config.hidden_size, config.init_method, config, config.sequence_parallel + ) + + if self.pre_process or self.post_process: + self.setup_embeddings_and_output_layer() + + # pylint: disable=line-too-long + def _sanity_check_attention_and_get_attn_mask_dimension(self) -> str: + """We do some checks and return attention mask dimensions for self attention + + Transformer engine library underwent a lot of change. 
So we need to change dimensions of
+        the attention mask depending on the TE version. We also sanity check some arguments.
+
+        1. If we use the local attention implementation, the mask dimension is [b,1,s,s]
+        2. If we use transformer engine >= 1.10, all 3 backends are supported with a padding mask and [b,1,1,s]
+        3. If we use transformer engine >= 1.7 but less than 1.10
+            a ) Flash and fused attention use a padding mask with [b,1,1,s]
+            b ) Unfused attention works with an arbitrary mask with [b,1,s,s]
+        4. If we use transformer engine < 1.7
+            Flash and fused attention are not supported. Unfused attention works with a padding mask [b,1,s,s]
+
+        By default (no NVTE_*_ATTN flags set), the fused path is used for transformer engine versions >= 1.7 and the unfused path otherwise.
+
+        Returns:
+            str: A string showing the format of the attn mask dimensions
+        """
+        attention_backend = self.config.attention_backend
+        attn_mask_dimensions = None
+        # For the local layer spec we just use b1ss
+        if (
+            self.transformer_layer_spec.submodules.self_attention.submodules.core_attention
+            == MCoreDotProductAttention
+        ):
+            assert attention_backend in [
+                AttnBackend.local,
+                AttnBackend.auto,
+            ], f'Expected AttnBackend to be local or auto while using mcore self attention, but found {attention_backend}. Set --attention-backend to local or do not use the MCore SelfAttention submodule in layer specs'
+            attn_mask_dimensions = "b1ss"
+        else:
+            attn_mask_type = self.transformer_layer_spec.submodules.self_attention.params[
+                'attn_mask_type'
+            ]
+            # For TE >= 1.10 we always use a padding mask and b11s
+            if is_te_min_version("1.10.0"):
+                attn_mask_dimensions = "b11s"
+                if attn_mask_type != AttnMaskType.padding:
+                    warnings.warn(
+                        f'For TE versions >= 1.10, flash/fused/unfused support padding mask. Setting attention mask from {attn_mask_type} to padding'
+                    )
+                    self.transformer_layer_spec.submodules.self_attention.params[
+                        'attn_mask_type'
+                    ] = AttnMaskType.padding
+            # For 1.7 <= TE < 1.10 the flash and fused paths use a padding mask with b11s and the unfused path uses an arbitrary mask with b1ss
+            elif is_te_min_version("1.7.0"):
+                if attention_backend in [AttnBackend.flash, AttnBackend.fused, AttnBackend.auto]:
+                    attn_mask_dimensions = "b11s"
+                else:
+                    if attn_mask_type != AttnMaskType.arbitrary:
+                        warnings.warn(
+                            f'For TE versions >= 1.7 but < 1.10, the unfused path supports only an arbitrary mask. Setting attention mask from {attn_mask_type} to arbitrary'
+                        )
+                        self.transformer_layer_spec.submodules.self_attention.params[
+                            'attn_mask_type'
+                        ] = AttnMaskType.arbitrary
+                    attn_mask_dimensions = "b1ss"
+            # For TE < 1.7 we only support unfused attention with b1ss and a padding mask
+            else:
+                attn_mask_dimensions = "b1ss"
+                assert not (attention_backend in [AttnBackend.flash, AttnBackend.fused]), (
+                    "Flash and fused attention is not supported with transformer engine version "
+                    "< 1.7.
Set --attention-backend to unfused or leave it to be default (auto) or upgrade transformer engine >= 1.7" + ) + + return attn_mask_dimensions + + def bert_extended_attention_mask(self, attention_mask: Tensor) -> Tensor: + """Creates the extended attention mask + + Converts the attention mask of dimension + [batch size, 1, seq len] to [batch size, 1, seq len, seq len] + or [batch size, 1, 1, seq_len] and makes it binary + + Args: + attention_mask (Tensor): The input attention mask + + Returns: + Tensor: The extended binary attention mask + """ + # We create a 3D attention mask from a 2D tensor mask. + if self.attn_mask_dimensions == "b1ss": + # [b, 1, s] + attention_mask_b1s = attention_mask.unsqueeze(1) + # [b, s, 1] + attention_mask_bs1 = attention_mask.unsqueeze(2) + # [b, s, s] + attention_mask_bss = attention_mask_b1s * attention_mask_bs1 + # [b, 1, s, s] + extended_attention_mask = attention_mask_bss.unsqueeze(1) + else: + # [b, 1, 1, s] + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) + + # Convert attention mask to binary: + extended_attention_mask = extended_attention_mask < 0.5 + + return extended_attention_mask + + def bert_position_ids(self, token_ids): + """Position ids for bert model""" + # Create position ids + seq_length = token_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(token_ids) + + return position_ids + + def set_input_tensor(self, input_tensor: Tensor) -> None: + """Sets input tensor to the model. + + See megatron.model.transformer.set_input_tensor() + + Args: + input_tensor (Tensor): Sets the input tensor for the model. + """ + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + + assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert' + self.encoder.set_input_tensor(input_tensor[0]) + + def forward( + self, + input_ids: Tensor, + attention_mask: Tensor, + tokentype_ids: Tensor = None, + lm_labels: Tensor = None, + inference_params=None, + ): + """Forward function of BERT model + + Forward function of the BERT Model This function passes the input tensors + through the embedding layer, and then the encoder and finally into the post + processing layer (optional). + + It either returns the Loss values if labels are given or the final hidden units + """ + extended_attention_mask = self.bert_extended_attention_mask(attention_mask) + + if parallel_state.is_pipeline_first_stage(): + input_ids = input_ids + position_ids = self.bert_position_ids(input_ids) + else: + position_ids = None + input_ids = None + + # Encoder embedding. + if self.pre_process: + encoder_input = self.embedding( + input_ids=input_ids, position_ids=position_ids, tokentype_ids=tokentype_ids + ) + else: + # intermediate stage of pipeline + # encoder will get hidden_states from encoder.input_tensor + encoder_input = None + + # Rotary positional embeddings (Why not move this into BERT/GPTEmberdding ?) + rotary_pos_emb = None + if self.position_embedding_type == 'rope': + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_params, self.encoder, encoder_input, self.config + ) + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + + # Run encoder. 
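# ---------------------------------------------------------------------------
# Editor's aside (illustrative sketch only, not part of the patch): what
# `bert_extended_attention_mask` above produces for the two mask layouts chosen
# by `_sanity_check_attention_and_get_attn_mask_dimension`. Values of 1 mark
# real tokens; the resulting boolean mask is True where attention is blocked.
# import torch
#
# padding_mask = torch.tensor([[1, 1, 1, 0]])  # [b=1, s=4], last token is padding
#
# # "b1ss": outer product of the mask with itself -> [b, 1, s, s]
# b1ss = (padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2)).unsqueeze(1) < 0.5
# # "b11s": broadcast along the query dimension -> [b, 1, 1, s]
# b11s = padding_mask.unsqueeze(1).unsqueeze(1) < 0.5
#
# assert b1ss.shape == (1, 1, 4, 4) and b11s.shape == (1, 1, 1, 4)
# ---------------------------------------------------------------------------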
+ hidden_states = self.encoder( + hidden_states=encoder_input, + attention_mask=extended_attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + ) + if not self.post_process: + return hidden_states + + if self.add_binary_head: + pooled_output = self.pooler(hidden_states, 0) + + if self.return_embeddings: + embeddings = torch.transpose(hidden_states, 0, 1) + masks = torch.sum(attention_mask, dim=1) + # Collect masked embeddings. + output = torch.zeros( + size=(embeddings.shape[0], embeddings.shape[2]), + dtype=torch.float32, + device=torch.cuda.current_device(), + ) + for i, (embedding, mask) in enumerate(zip(embeddings, masks)): + output[i, :] = torch.mean(embedding[1 : mask - 1], dim=0) + return output + + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + + hidden_states_after_lm_head = self.lm_head(hidden_states=hidden_states) + logits, _ = self.output_layer(hidden_states_after_lm_head, weight=output_weight) + + binary_logits = None + if self.binary_head is not None: + binary_logits = self.binary_head(pooled_output) + + if lm_labels is None: + # [s b h] => [b s h] + return logits.transpose(0, 1).contiguous(), binary_logits + + loss = self.compute_language_model_loss(lm_labels, logits) + + return loss, binary_logits diff --git a/megatron/core/models/bert/pooler.py b/megatron/core/models/bert/pooler.py new file mode 100644 index 0000000000..e0de1a845a --- /dev/null +++ b/megatron/core/models/bert/pooler.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import torch +from torch import Tensor + +from megatron.core import tensor_parallel +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.utils import get_linear_layer + + +class Pooler(MegatronModule): + """Pooler layer. + + Pool hidden states of a specific token (for example start of the + sequence) and add a linear transformation followed by a tanh. + + Args: + hidden_size (int): The hidden size_ + init_method (callable): weight initialization method for the linear layer. bias is set to zero. + config (TransformerConfig): The transformer configuration + sequence_parallel (bool): Using squence parallel ? Defaults to False + """ + + def __init__( + self, + hidden_size: int, + init_method: callable, + config: TransformerConfig, + sequence_parallel: bool = False, + ): + super(Pooler, self).__init__(config) + # TODO: Shoudl switch this to TE ? + self.dense = get_linear_layer( + hidden_size, hidden_size, init_method, config.perform_initialization + ) + self.sequence_parallel = sequence_parallel + + def forward(self, hidden_states: Tensor, sequence_index=0): + # hidden_states: [s, b, h] + # sequence_index: index of the token to pool. 
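# ---------------------------------------------------------------------------
# Editor's aside (illustrative sketch only, not part of the patch): what the
# pooler computes, assuming the usual [s, b, h] hidden-state layout from the
# encoder: take the hidden state of one token (index 0 by default) and apply a
# dense projection followed by tanh.
# import torch
#
# s, b, h = 8, 2, 16
# hidden = torch.randn(s, b, h)
# dense = torch.nn.Linear(h, h)                # stands in for get_linear_layer(...)
# pooled = torch.tanh(dense(hidden[0, :, :]))  # [b, h]
# assert pooled.shape == (b, h)
# ---------------------------------------------------------------------------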
+ + # gather data along sequence dimensions + # same pooler is run on all tensor parallel nodes + if self.sequence_parallel: + hidden_states = tensor_parallel.gather_from_sequence_parallel_region( + hidden_states, tensor_parallel_output_grad=False + ) + + pooled = hidden_states[sequence_index, :, :] + pooled = self.dense(pooled) + pooled = torch.tanh(pooled) + return pooled diff --git a/megatron/core/models/common/embeddings/__init__.py b/megatron/core/models/common/embeddings/__init__.py new file mode 100644 index 0000000000..865f96da5d --- /dev/null +++ b/megatron/core/models/common/embeddings/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from .rope_utils import apply_rotary_pos_emb +from .rotary_pos_embedding import RotaryEmbedding +from .yarn_rotary_pos_embedding import YarnRotaryEmbedding, _yarn_get_mscale diff --git a/megatron/core/models/common/embeddings/language_model_embedding.py b/megatron/core/models/common/embeddings/language_model_embedding.py index 5158f4c0af..2c7fec6564 100644 --- a/megatron/core/models/common/embeddings/language_model_embedding.py +++ b/megatron/core/models/common/embeddings/language_model_embedding.py @@ -1,6 +1,6 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -from typing import Literal, Optional +from typing import Literal import torch from torch import Tensor @@ -8,22 +8,21 @@ from megatron.core import tensor_parallel from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import ( - make_sharded_tensor_for_checkpoint, - make_tp_sharded_tensor_for_checkpoint, -) class LanguageModelEmbedding(MegatronModule): """Language model embeddings. - Arguments: + Args: config (TransformerConfig): config object with all necessary configs for TransformerBlock vocab_size (int): vocabulary size max_sequence_length (int): maximum size of sequence. This is used for positional embedding add_position_embedding (bool): Add a position embedding. - embedding_dropout_prob float): dropout probability for embeddings + embedding_dropout_prob (float): dropout probability for embeddings + num_tokentypes (int): Set to 0 without binary head, and 2 with a binary head. Defaults to 0. + scatter_to_sequence_parallel (bool): Set to False to disable scatter of embedding + across sequence parallel region. Defaults to True. """ def __init__( @@ -31,7 +30,9 @@ def __init__( config: TransformerConfig, vocab_size: int, max_sequence_length: int, - position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute', + position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute', + num_tokentypes: int = 0, + scatter_to_sequence_parallel: bool = True, ): super().__init__(config=config) @@ -39,12 +40,21 @@ def __init__( self.vocab_size: int = vocab_size self.max_sequence_length: int = max_sequence_length self.add_position_embedding: bool = position_embedding_type == 'learned_absolute' + self.num_tokentypes = num_tokentypes + self.scatter_to_sequence_parallel = scatter_to_sequence_parallel + self.reduce_scatter_embeddings = ( + (not self.add_position_embedding) + and self.num_tokentypes <= 0 + and self.config.sequence_parallel + and self.scatter_to_sequence_parallel + ) # Word embeddings (parallel). 
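# ---------------------------------------------------------------------------
# Editor's aside (illustrative sketch only, not part of the patch): with a
# binary head, BERT passes num_tokentypes=2 and adds a token-type (segment)
# embedding on top of the word (and position) embeddings, as in the forward
# pass of this module. A toy version with made-up sizes:
# import torch
#
# vocab, hidden, num_tokentypes, b, s = 100, 8, 2, 2, 4
# word_emb = torch.nn.Embedding(vocab, hidden)
# tokentype_emb = torch.nn.Embedding(num_tokentypes, hidden)
#
# input_ids = torch.randint(0, vocab, (b, s))
# tokentype_ids = torch.zeros(b, s, dtype=torch.long)     # all sentence-A segments
#
# embeddings = word_emb(input_ids) + tokentype_emb(tokentype_ids)  # [b, s, h]
# embeddings = embeddings.transpose(0, 1).contiguous()             # -> [s, b, h]
# ---------------------------------------------------------------------------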
self.word_embeddings = tensor_parallel.VocabParallelEmbedding( num_embeddings=self.vocab_size, embedding_dim=self.config.hidden_size, init_method=self.config.init_method, + reduce_scatter_embeddings=self.reduce_scatter_embeddings, config=self.config, ) @@ -58,6 +68,16 @@ def __init__( if self.config.perform_initialization: self.config.init_method(self.position_embeddings.weight) + if self.num_tokentypes > 0: + self.tokentype_embeddings = torch.nn.Embedding( + self.num_tokentypes, self.config.hidden_size + ) + # Initialize the token-type embeddings. + if self.config.perform_initialization: + self.config.init_method(self.tokentype_embeddings.weight) + else: + self.tokentype_embeddings = None + # Embeddings dropout self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout) @@ -67,12 +87,18 @@ def zero_parameters(self): self.word_embeddings.weight.shared = True self.position_embeddings.weight.data.fill_(0) self.position_embeddings.weight.shared = True + if self.num_tokentypes > 0: + self.tokentype_embeddings.weight.data.fill_(0) + self.tokentype_embeddings.weight.shared = True + + def forward(self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: int = None) -> Tensor: + """Forward pass of the embedding module. - def forward(self, input_ids: Tensor, position_ids: Tensor) -> Tensor: - """Forward pass of the embedding module Args: input_ids (Tensor): The input tokens position_ids (Tensor): The position id's used to calculate position embeddings + tokentype_ids (int): The token type ids. Used when args.bert_binary_head is + set to True. Defaults to None Returns: Tensor: The output embeddings @@ -84,8 +110,17 @@ def forward(self, input_ids: Tensor, position_ids: Tensor) -> Tensor: else: embeddings = word_embeddings - # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. - embeddings = embeddings.transpose(0, 1).contiguous() + if not self.reduce_scatter_embeddings: + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + + if tokentype_ids is not None: + assert self.tokentype_embeddings is not None + # [b s h] -> [s b h] (So that it can be added with embeddings) + tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2) + embeddings = embeddings + tokentype_embedding + else: + assert self.tokentype_embeddings is None # If the input flag for fp32 residual connection is set, convert for float. if self.config.fp32_residual_connection: @@ -93,41 +128,16 @@ def forward(self, input_ids: Tensor, position_ids: Tensor) -> Tensor: # Dropout. if self.config.sequence_parallel: - embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings) + if not self.reduce_scatter_embeddings and self.scatter_to_sequence_parallel: + embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings) + # `scatter_to_sequence_parallel_region` returns a view, which prevents + # the original tensor from being garbage collected. Clone to facilitate GC. + # Has a small runtime cost (~0.5%). + if self.config.clone_scatter_output_in_embedding and self.scatter_to_sequence_parallel: + embeddings = embeddings.clone() with tensor_parallel.get_cuda_rng_tracker().fork(): embeddings = self.embedding_dropout(embeddings) else: embeddings = self.embedding_dropout(embeddings) return embeddings - - def sharded_state_dict(self, prefix=''): - - sharded_state_dict = {} - - word_embeddings_prefix = f'{prefix}word_embeddings.' 
- word_embeddings_state_dict = self.word_embeddings.state_dict( - prefix=word_embeddings_prefix, keep_vars=True - ) - - sharded_word_embeddings_key = f'{word_embeddings_prefix}weight' - sharded_word_embeddings_tensor = make_tp_sharded_tensor_for_checkpoint( - tensor=word_embeddings_state_dict[sharded_word_embeddings_key], - key=sharded_word_embeddings_key, - allow_shape_mismatch=True, - ) - sharded_state_dict[sharded_word_embeddings_key] = sharded_word_embeddings_tensor - - if self.add_position_embedding: - position_embeddings_prefix = f'{prefix}position_embeddings.' - position_embeddings_state_dict = self.position_embeddings.state_dict( - prefix=position_embeddings_prefix, keep_vars=True - ) - sharded_position_embeddings_key = f'{position_embeddings_prefix}weight' - sharded_position_embeddings_tensor = make_sharded_tensor_for_checkpoint( - tensor=position_embeddings_state_dict[sharded_position_embeddings_key], - key=sharded_position_embeddings_key, - ) - sharded_state_dict[sharded_position_embeddings_key] = sharded_position_embeddings_tensor - - return sharded_state_dict diff --git a/megatron/core/models/common/embeddings/language_module/language_module.py b/megatron/core/models/common/embeddings/language_module/language_module.py deleted file mode 100644 index 473a2970bd..0000000000 --- a/megatron/core/models/common/embeddings/language_module/language_module.py +++ /dev/null @@ -1,102 +0,0 @@ -import logging - -import torch -from torch import Tensor - -from megatron.core import parallel_state, tensor_parallel -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.transformer_config import TransformerConfig - - -class LanguageModule(MegatronModule): - """Base language module that has common helper functions used across GPT, BERT etc. - - Args: - config (TransformerConfig): Input transformer config for the model - """ - - def __init__(self, config: TransformerConfig) -> None: - super().__init__(config=config) - - def set_input_tensor(self, input_tensor: Tensor) -> None: - """Sets input tensor to the model. - - See megatron.model.transformer.set_input_tensor() - - Args: - input_tensor (Tensor): Sets the input tensor for the model. - """ - # This is usually handled in schedules.py but some inference code still - # gives us non-lists or None - if not isinstance(input_tensor, list): - input_tensor = [input_tensor] - - assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt' - self.decoder.set_input_tensor(input_tensor[0]) - - def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor: - """Computes the language model loss (Cross entropy across vocabulary) - - Args: - labels (Tensor): The labels of dimension [batch size, seq length] - logits (Tensor): The final logits returned by the output layer of the transformer model - - Returns: - Tensor: Loss tensor of dimensions [batch size, sequence_length] - """ - # [b s] => [s b] - labels = labels.transpose(0, 1).contiguous() - loss = tensor_parallel.vocab_parallel_cross_entropy(logits.float(), labels) - - # [s b] => [b, s] - loss = loss.transpose(0, 1).contiguous() - return loss - - def initialize_last_stage_with_word_embeddings(self) -> None: - """Intializes the word embeddings in the final stage. - - This function just initalizes word embeddings in the final stage, when we are - using pipeline parallelism and sharind word embeddings. 
Nothing to do if we - arn't sharing weights or aren't using Pipeline parallelism - """ - if not self.share_embeddings_and_output_weights or (self.pre_process and self.post_process): - return - - if self.post_process and not self.pre_process: - assert not parallel_state.is_pipeline_first_stage() - # set word_embeddings weights to 0 here, then copy first - # stage's weights using all_reduce below. - self.output_layer.weight.data.fill_(0) - self.output_layer.weight.shared = True - - # Parameters are shared between the word embeddings layers, and the - # heads at the end of the model. In a pipelined setup with more than - # one stage, the initial embedding layer and the head are on different - # workers, so we do the following: - # 1. Create a second copy of word_embeddings on the last stage, with - # initial parameters of 0.0. - # 2. Do an all-reduce between the first and last stage to ensure that - # the two copies of word_embeddings start off with the same - # parameter values. - # 3. In the training loop, before an all-reduce between the grads of - # the two word_embeddings layers to ensure that every applied weight - # update is the same on both stages. - - # Ensure that first and last stages have the same initial parameter - # values. - if torch.distributed.is_initialized(): - if parallel_state.is_rank_in_embedding_group(): - weight = self.shared_embedding_or_output_weight() - torch.distributed.all_reduce( - weight.data, group=parallel_state.get_embedding_group() - ) - - elif not getattr(LanguageModule, "embedding_warning_printed", False): - logging.getLogger(__name__).warning( - "Distributed processes aren't initialized, so the output layer " - "is not initialized with weights from the word embeddings. " - "If you are just manipulating a model this is fine, but " - "this needs to be handled manually. If you are training " - "something is definitely wrong." - ) - LanguageModule.embedding_warning_printed = True diff --git a/megatron/core/models/common/embeddings/relative_pos_embedding.py b/megatron/core/models/common/embeddings/relative_pos_embedding.py new file mode 100644 index 0000000000..af17bce1cc --- /dev/null +++ b/megatron/core/models/common/embeddings/relative_pos_embedding.py @@ -0,0 +1,173 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +import logging +import math +from typing import Callable + +import torch +from torch import Tensor, nn + +from megatron.core.inference_params import InferenceParams +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_config import TransformerConfig + +logger = logging.getLogger(__name__) + + +__all__ = ['RelativePositionEmbedding'] + + +class RelativePositionEmbedding(nn.Module): + """Relative Position Embedding for language model. 
+ + Args: + + """ + + def __init__( + self, + bidirectional: bool, + init_method: Callable, + num_attention_heads: int, + relative_attention_num_buckets: int = 32, + relative_attention_max_distance: int = 128, + ) -> None: + super().__init__() + + self.bidirectional = bidirectional + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.relative_attention_bias = torch.nn.Embedding( + self.relative_attention_num_buckets, num_attention_heads + ) + init_method(self.relative_attention_bias.weight) + + def _relative_position_bucket( + self, relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ): + """ + Adapted from HuggingFace T5 Model: + https://github.com/huggingface/transformers/blob/329f5dbf97a5cb2473914c88c05aa3dcb242e19a/ + src/transformers/models/t5/modeling_t5.py#L397 + + Translate relative position to a bucket number for relative attention. + The relative position is defined as memory_position - query_position, i.e. the + distance in tokens from the attending position to the attended-to position. + If bidirectional=False, then positive relative positions are invalid. We use + smaller buckets for small absolute relative_position and larger buckets for + larger absolute relative_positions. All relative positions >=max_distance map + to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the + model has been trained on. + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + Returns: + a Tensor with the same shape as relative_position, + containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger + # bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def _compute_bias(self, query_length, key_length): + """ + Adapted from HuggingFace T5 Model + https://github.com/huggingface/transformers/blob/329f5dbf97a5cb2473914c88c05aa3dcb242e19a/ + src/transformers/models/t5/modeling_t5.py#L444C9-L444C21 + + Compute binned relative position bias + + Args: + query_length (int): The length of the query sequence + (e.g., the input sequence in attention). + key_length (int): The length of the key sequence + (e.g., the sequence to compare against in attention). + + Returns: + torch.Tensor: A tensor representing the relative position bias, with shape + (1, num_heads, query_length, key_length). 
+ """ + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + + relative_position = memory_position - context_position # shape(query_length,key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=self.bidirectional, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape(query_length,key_length,num_heads) + values = values.permute([2, 0, 1]).unsqueeze( + 0 + ) # shape(1, num_heads,query_length,key_length) + return values + + @staticmethod + def get_relative_seq_len( + inference_params: InferenceParams, + transformer: TransformerBlock, + transformer_input: Tensor, + transformer_config: TransformerConfig, + ) -> float: + """Function to get the rotary sequence length. + + Args: + inference_params : Used during Inference time + transformer (TransformerBlock): The transformer block (decoder/encoder) used + by the model + transformer_input (Tensor): Input tensor to the transformer + transformer_config (TransformerConfig): Transformer config used by the model + + Returns: + float: The rotary sequence length + """ + if inference_params is not None: + relative_seq_len = inference_params.max_sequence_length + else: + if transformer.input_tensor is not None: + relative_seq_len = transformer.input_tensor.size(0) + else: + relative_seq_len = transformer_input.size(0) + + if transformer_config.sequence_parallel: + relative_seq_len *= transformer_config.tensor_model_parallel_size + + return relative_seq_len + + def forward(self, query_seq_length, key_seq_length): + """ + Args: + Returns: + """ + return self._compute_bias(query_seq_length, key_seq_length) diff --git a/megatron/core/models/common/embeddings/rope_utils.py b/megatron/core/models/common/embeddings/rope_utils.py new file mode 100644 index 0000000000..3dd5193ca2 --- /dev/null +++ b/megatron/core/models/common/embeddings/rope_utils.py @@ -0,0 +1,261 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from megatron.core.transformer.transformer_config import TransformerConfig + +import logging + +import torch +from torch import Tensor + +from megatron.core import parallel_state +from megatron.core.utils import is_te_min_version + +logger = logging.getLogger(__name__) + +# Prefer fused RoPE from Apex as we need the `transpose_output_memory` argument for the bshd trick. +# See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2469. 
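# ---------------------------------------------------------------------------
# Editor's aside (illustrative sketch only, not part of the patch): a standalone
# re-implementation of the T5-style bucketing in
# `RelativePositionEmbedding._relative_position_bucket` above, to show its
# behaviour: small |relative_position| values get exact buckets, larger
# distances share logarithmically spaced buckets, and positive offsets use the
# upper half of the bucket range in the bidirectional case.
# import math
# import torch
#
# def relative_bucket(relative_position, num_buckets=32, max_distance=128, bidirectional=True):
#     buckets = torch.zeros_like(relative_position)
#     if bidirectional:
#         num_buckets //= 2
#         buckets = buckets + (relative_position > 0).long() * num_buckets
#         relative_position = relative_position.abs()
#     else:
#         relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
#     max_exact = num_buckets // 2
#     is_small = relative_position < max_exact
#     large = max_exact + (
#         torch.log(relative_position.float() / max_exact)
#         / math.log(max_distance / max_exact)
#         * (num_buckets - max_exact)
#     ).long()
#     large = torch.min(large, torch.full_like(large, num_buckets - 1))
#     return buckets + torch.where(is_small, relative_position, large)
#
# print(relative_bucket(torch.arange(-130, 131, 65)))  # tensor([15, 14,  0, 30, 31])
# ---------------------------------------------------------------------------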
+try: + from apex.transformer.functional import fused_apply_rotary_pos_emb +except ImportError: + try: + from megatron.core.extensions.transformer_engine import fused_apply_rotary_pos_emb + except: + fused_apply_rotary_pos_emb = None + + +try: + from megatron.core.extensions.transformer_engine import fused_apply_rotary_pos_emb_thd +except ImportError: + try: + from apex.transformer.functional import fused_apply_rotary_pos_emb_thd + except ImportError: + fused_apply_rotary_pos_emb_thd = None + + +try: + from flash_attn.layers.rotary import apply_rotary_emb as apply_rotary_emb_flash +except ImportError: + apply_rotary_emb_flash = None + + +__all__ = ['apply_rotary_emb_flash'] + + +def get_pos_emb_on_this_cp_rank(pos_emb: Tensor, seq_dim: int) -> Tensor: + """Get the position embedding on the current context parallel rank. + + Args: + pos_emb (Tensor): Positional embedding tensor + seq_dim (int): Sequence dimension + """ + cp_size = parallel_state.get_context_parallel_world_size() + cp_rank = parallel_state.get_context_parallel_rank() + cp_idx = torch.tensor( + [cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True + ).cuda(non_blocking=True) + pos_emb = pos_emb.view( + *pos_emb.shape[:seq_dim], 2 * cp_size, -1, *pos_emb.shape[(seq_dim + 1) :] + ) + pos_emb = pos_emb.index_select(seq_dim, cp_idx) + pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :]) + return pos_emb + + +def _rotate_half(x: Tensor, rotary_interleaved: bool) -> Tensor: + """Change sign so the last dimension becomes [-odd, +even] + + Args: + x (Tensor): Input tensor + + Returns: + Tensor: Tensor rotated half + """ + if not rotary_interleaved: + x1, x2 = torch.chunk(x, 2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x_new = torch.stack((-x2, x1), dim=-1) + return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1) + + +def _apply_rotary_pos_emb_bshd( + t: Tensor, + freqs: Tensor, + rotary_interleaved: bool = False, + multi_latent_attention: bool = False, + mscale: float = 1.0, +) -> Tensor: + """Apply rotary positional embedding to input tensor T. + + check https://kexue.fm/archives/8265 for detailed formulas + + Args: + t (Tensor): Input tensor T is of shape [seq_length, ... 
, dim] + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim] + + Returns: + Tensor: The input tensor after applying RoPE + """ + rot_dim = freqs.shape[-1] + + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + + if multi_latent_attention: + x1 = t[..., 0::2] + x2 = t[..., 1::2] + t = torch.cat((x1, x2), dim=-1) + + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + cos_ = (torch.cos(freqs) * mscale).to(t.dtype) + sin_ = (torch.sin(freqs) * mscale).to(t.dtype) + + t = (t * cos_) + (_rotate_half(t, rotary_interleaved) * sin_) + return torch.cat((t, t_pass), dim=-1) + + +def _get_thd_freqs_on_this_cp_rank(cp_rank: int, cp_size: int, x: Tensor, freqs: Tensor) -> Tensor: + if cp_size > 1: + cp_seg = x.size(0) // 2 + full_seqlen = cp_size * x.size(0) + return torch.cat( + [ + freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg], + freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg], + ] + ) + else: + return freqs[: x.size(0)] + + +def _apply_rotary_pos_emb_thd( + t: Tensor, + cu_seqlens: Tensor, + freqs: Tensor, + rotary_interleaved: bool = False, + multi_latent_attention: bool = False, + mscale: float = 1.0, +) -> Tensor: + """A baseline implementation of applying RoPE for `thd` format. + + Args: + t (Tensor): Input tensor T is of shape [t, h, d] + cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, + with shape [b + 1] and dtype torch.int32. + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] + + Returns: + Tensor: Shape [t, h, d]. The input tensor after applying RoPE. + """ + + cp_size = parallel_state.get_context_parallel_world_size() + cp_rank = parallel_state.get_context_parallel_rank() + cu_seqlens = cu_seqlens // cp_size + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + + return torch.cat( + [ + _apply_rotary_pos_emb_bshd( + x.unsqueeze(1), + _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs), + rotary_interleaved=rotary_interleaved, + multi_latent_attention=multi_latent_attention, + mscale=mscale, + ) + for x in torch.split(t, seqlens) + ] + ).squeeze(1) + + +def apply_rotary_pos_emb( + t: Tensor, + freqs: Tensor, + config: TransformerConfig, + cu_seqlens: Optional[Tensor] = None, + mscale: float = 1.0, +): + """ + Reroute to the appropriate apply_rotary_pos_emb function depending on + fused/unfused kernels, or bshd (conventional) / thd (packed seq) format + """ + + if config.apply_rope_fusion: + if cu_seqlens is None: + assert fused_apply_rotary_pos_emb is not None, "apply_rope_fusion is not available." + return fused_apply_rotary_pos_emb(t, freqs, transpose_output_memory=True) + else: + assert fused_apply_rotary_pos_emb_thd is not None, "apply_rope_fusion is not available." 
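# ---------------------------------------------------------------------------
# Editor's aside (illustrative sketch only, not part of the patch): the unfused
# rotary update computed by `_apply_rotary_pos_emb_bshd` above, on a toy tensor:
# t_rot = t * cos(freqs) + rotate_half(t) * sin(freqs), applied to the first
# rot_dim channels (here the full hidden dim).
# import torch
#
# def rotate_half(x):
#     x1, x2 = torch.chunk(x, 2, dim=-1)
#     return torch.cat((-x2, x1), dim=-1)
#
# s, b, np_, hn = 4, 1, 1, 8                                    # [seq, batch, heads, dim]
# t = torch.randn(s, b, np_, hn)
# inv_freq = 1.0 / (10000 ** (torch.arange(0, hn, 2).float() / hn))
# freqs = torch.outer(torch.arange(s).float(), inv_freq)        # [s, hn/2]
# freqs = torch.cat((freqs, freqs), dim=-1)[:, None, None, :]   # [s, 1, 1, hn]
# t_rot = t * torch.cos(freqs) + rotate_half(t) * torch.sin(freqs)
# assert t_rot.shape == t.shape
# ---------------------------------------------------------------------------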
+ cp_size = parallel_state.get_context_parallel_world_size() + if cp_size > 1: + if not is_te_min_version("1.11.0", check_equality=False): + raise ValueError("Only TE >= 1.12 supports RoPE fusion for THD format with CP.") + return fused_apply_rotary_pos_emb_thd( + t, + cu_seqlens, + freqs, + cp_size=cp_size, + cp_rank=parallel_state.get_context_parallel_rank(), + ) + else: + return fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs) + else: + if cu_seqlens is None: + return _apply_rotary_pos_emb_bshd( + t, + freqs, + rotary_interleaved=config.rotary_interleaved, + multi_latent_attention=config.multi_latent_attention, + mscale=mscale, + ) + else: + return _apply_rotary_pos_emb_thd( + t, + cu_seqlens, + freqs, + rotary_interleaved=config.rotary_interleaved, + multi_latent_attention=config.multi_latent_attention, + mscale=mscale, + ) + + +def apply_rotary_pos_emb_with_cos_sin( + t: Tensor, cos: Tensor, sin: Tensor, rotary_interleaved: bool = False +) -> Tensor: + """ + This function applies rotary positional embedding to the target tensor t + using precomputed cos and sin of size (seq_len, d_rot / 2) + """ + cos = cos.to(t.dtype) + sin = sin.to(t.dtype) + + if apply_rotary_emb_flash is None: + # Combine cos and sin into freqs + freqs = torch.stack([cos, sin], dim=-1).flatten(start_dim=-2) + + # Expand freqs to match t's shape + while freqs.dim() < t.dim(): + freqs = freqs.unsqueeze(1) + freqs = freqs.expand(t.shape[:-1] + (-1,)) + + y = _apply_rotary_pos_emb_bshd( + t, + freqs, + rotary_interleaved=rotary_interleaved, + multi_latent_attention=False, + mscale=1.0, + ) + else: + # Use Flash Attention's optimized kernel for rotary embedding + t = t.permute(1, 0, 2, 3) + y = apply_rotary_emb_flash(t, cos, sin, rotary_interleaved) + y = y.permute(1, 0, 2, 3) + + return y diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py index d098e4561f..e3923d016b 100644 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -7,64 +7,116 @@ if TYPE_CHECKING: from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_block import TransformerBlock + from megatron.core.inference_params import InferenceParams + from megatron.core.packed_seq_params import PackedSeqParams + +import logging +import math +from functools import lru_cache import torch from torch import Tensor, nn from megatron.core import parallel_state +from megatron.core.models.common.embeddings.rope_utils import ( # for backward compatibility; pylint: disable=unused-import + _apply_rotary_pos_emb_bshd, + _apply_rotary_pos_emb_thd, + _rotate_half, + apply_rotary_pos_emb, + get_pos_emb_on_this_cp_rank, +) -__all__ = ['RotaryEmbedding', 'apply_rotary_pos_emb'] +logger = logging.getLogger(__name__) -def get_pos_emb_on_this_cp_rank(pos_emb, seq_dim): - cp_size = parallel_state.get_context_parallel_world_size() - cp_rank = parallel_state.get_context_parallel_rank() - cp_idx = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=pos_emb.device) - pos_emb = pos_emb.view( - *pos_emb.shape[:seq_dim], 2 * cp_size, -1, *pos_emb.shape[(seq_dim + 1) :] - ) - pos_emb = pos_emb.index_select(seq_dim, cp_idx) - pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :]) - return pos_emb +__all__ = ['RotaryEmbedding'] class RotaryEmbedding(nn.Module): """Rotary Embedding for language model. 
Args: - kv_channels (int): Projection weights dimension in multi-head attention. Obtained from transformer config - rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings. - seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None + kv_channels (int): Projection weights dimension in multi-head attention. Obtained + from transformer config + rotary_percent (float): Percent of rotary dimension to use for rotary position + embeddings. + rotary_interleaved (bool, optional): If True, interleaved rotary position embeddings. + Defaults to False. + seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE + for longer sequences. The value must be a float larger than 1.0. Defaults to None + rotary_base (int, optional): Base period for rotary position embeddings. Defaults to + 10000. + rope_scaling (bool, optional): Apply rope scaling as used in llama 3.x. + rope_scaling_factor (float, optional): rope scaling factor in llama 3.x. Defaults to 8. + use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly + on the GPU. Defaults to False """ def __init__( - self, kv_channels: int, rotary_percent: float, seq_len_interpolation_factor: float = None + self, + kv_channels: int, + rotary_percent: float, + rotary_interleaved: bool = False, + seq_len_interpolation_factor: float = None, + rotary_base: int = 10000, + rope_scaling: bool = False, + rope_scaling_factor: float = 8.0, + use_cpu_initialization: bool = False, ) -> None: super().__init__() dim = kv_channels if rotary_percent < 1.0: dim = int(dim * rotary_percent) + self.rotary_interleaved = rotary_interleaved self.seq_len_interpolation_factor = seq_len_interpolation_factor + device = 'cpu' if use_cpu_initialization else torch.cuda.current_device() self.inv_freq = 1.0 / ( - 10000 - ** ( - torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device()) - / dim - ) + rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) ) - def forward(self, max_seq_len: int, offset: int = 0) -> Tensor: - """Forward pass of RoPE embedding. + if rope_scaling: + self.inv_freq = self._apply_scaling(self.inv_freq, factor=rope_scaling_factor) - Args: - max_seq_len (int): Maximum size of sequence - offset (int, optional): _description_. Defaults to 0. 
+ def _apply_scaling( + self, + freqs, + factor=8, + low_freq_factor=1, + high_freq_factor=4, + original_max_position_embeddings=8192, + ): + # This implementation is adapted from: + # https://github.com/huggingface/transformers/blob/2a5a6ad18aa22e98429bb5ecb880660328030ea0/src/transformers/modeling_rope_utils.py#L303-L343 + + factor = factor # `8` in the original implementation + low_freq_factor = low_freq_factor # `1` in the original implementation + high_freq_factor = high_freq_factor # `4` in the original implementation + old_context_len = original_max_position_embeddings # `8192` in the original implementation + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / freqs + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, freqs / factor, freqs) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + smoothed_inv_freq = ( + 1 - smooth_factor + ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) - Returns: - Tensor: Embeddings after applying RoPE. - """ + return inv_freq_llama + + def get_freqs_non_repeated(self, max_seq_len: int, offset: int = 0) -> Tensor: + """Generates matrix of frequencies based on positions in the sequence, + used to create positional encodings""" seq = ( torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + offset @@ -73,14 +125,48 @@ def forward(self, max_seq_len: int, offset: int = 0) -> Tensor: if self.seq_len_interpolation_factor is not None: seq *= 1 / self.seq_len_interpolation_factor - freqs = torch.outer(seq, self.inv_freq) + freqs = torch.outer(seq, self.inv_freq) # [seq len, dim] + + return freqs + + def get_cos_sin(self, max_seq_len: int, offset: int = 0) -> (Tensor, Tensor): + """Cosine and sine values for RoPE are precomputed for all positions up to the maximum + sequence length""" + freqs = self.get_freqs_non_repeated(max_seq_len, offset) + cos = torch.cos(freqs) + sin = torch.sin(freqs) + return cos, sin + + @lru_cache(maxsize=32) + def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor: + """Forward pass of RoPE embedding. + + Args: + max_seq_len (int): Maximum size of sequence + offset (int, optional): RoPE offset. Defaults to 0. + packed_seq (bool, optional): Whether to use packed sequence. Defaults to False. + + Returns: + Tensor: Embeddings after applying RoPE. 
+ """ + if self.inv_freq.device.type == 'cpu': + # move `inv_freq` to GPU once at the first micro-batch forward pass + self.inv_freq = self.inv_freq.to(device=torch.cuda.current_device()) + + freqs = self.get_freqs_non_repeated(max_seq_len, offset) # first part even vector components, second part odd vector components, # 2 * dim in dimension size - emb = torch.cat((freqs, freqs), dim=-1) + if not self.rotary_interleaved: + emb = torch.cat((freqs, freqs), dim=-1) + else: + emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view( + freqs.shape[0], -1 + ) # emb [seq_length, .., dim] emb = emb[:, None, None, :] - if parallel_state.get_context_parallel_world_size() > 1: - # slice rotary_pos_emb along sequence dimension and select the parition of the current CP rank + if parallel_state.get_context_parallel_world_size() > 1 and not packed_seq: + # slice rotary_pos_emb along sequence dimension and select the parition of the current + # CP rank emb = get_pos_emb_on_this_cp_rank(emb, 0) return emb @@ -90,23 +176,30 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): def get_rotary_seq_len( self, - inference_params, + inference_params: InferenceParams, transformer: TransformerBlock, transformer_input: Tensor, transformer_config: TransformerConfig, + packed_seq_params: PackedSeqParams, ) -> float: """Function to get the rotary sequence length. Args: inference_params : Used during Inference time - transformer (TransformerBlock): The transformer block (decoder/encoder) used by the model - transformer_input (Tensor): _description_ + transformer (TransformerBlock): The transformer block (decoder/encoder) used + by the model + transformer_input (Tensor): Input tensor to the transformer transformer_config (TransformerConfig): Transformer config used by the model + packed_seq_params (PackedSeqParams): Packed sequence params Returns: float: The rotary sequence length """ - if inference_params is not None: + if packed_seq_params is not None: + # max_seqlen are the max sequence length in the packed sequence before being divived + # by the tp and cp size. + return max(packed_seq_params.max_seqlen_q, packed_seq_params.max_seqlen_kv) + elif inference_params is not None: rotary_seq_len = inference_params.max_sequence_length else: if transformer.input_tensor is not None: @@ -120,43 +213,3 @@ def get_rotary_seq_len( rotary_seq_len *= transformer_config.context_parallel_size return rotary_seq_len - - -def _rotate_half(x: Tensor) -> Tensor: - """Change sign so the last dimension becomes [-odd, +even] - - Args: - x (Tensor): Input tensor - - Returns: - Tensor: Tensor rotated half - """ - - x1, x2 = torch.chunk(x, 2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(t: Tensor, freqs: Tensor) -> Tensor: - """Apply rotary positional embedding to input tensor T. - - check https://kexue.fm/archives/8265 for detailed formulas - - Args: - t (Tensor): Input tensor T is of shape [seq_length, ... 
, dim] - freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim] - - Returns: - Tensor: The input tensor after applying RoPE - """ - rot_dim = freqs.shape[-1] - - # ideally t_pass is empty so rotary pos embedding is applied to all tensor t - t, t_pass = t[..., :rot_dim], t[..., rot_dim:] - - # first part is cosine component - # second part is sine component, need to change signs with _rotate_half method - cos_ = torch.cos(freqs).to(t.dtype) - sin_ = torch.sin(freqs).to(t.dtype) - - t = (t * cos_) + (_rotate_half(t) * sin_) - return torch.cat((t, t_pass), dim=-1) diff --git a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py new file mode 100644 index 0000000000..3ab155dcdb --- /dev/null +++ b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py @@ -0,0 +1,179 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +from __future__ import annotations + +import logging +import math +from functools import lru_cache + +import torch +from torch import Tensor + +from megatron.core import parallel_state +from megatron.core.models.common.embeddings.rope_utils import get_pos_emb_on_this_cp_rank +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding + +logger = logging.getLogger(__name__) + + +class YarnRotaryEmbedding(RotaryEmbedding): + """Yarn Rotary Embedding for language model. + + Args: + kv_channels (int): Projection weights dimension in multi-head attention. Obtained from + transformer config + rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings. + rotary_interleaved (bool, optional): If True, interleaved rotary position embeddings. + Defaults to False. + seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE for + longer sequences. The value must be a float larger than 1.0. Defaults to None + rotary_base (float, optional): Base period for rotary position embeddings. Defaults to + 10000. + use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly on + the GPU. Defaults to False + scaling_factor (float, optional): Scaling factor for Yarn RoPE. Defaults to 1.0. + original_max_position_embeddings (int, optional): Original maximum position embeddings + length. Defaults to 4096. + beta_fast (float, optional): Fast beta value for Yarn RoPE. Defaults to 32. + beta_slow (float, optional): Slow beta value for Yarn RoPE. Defaults to 1. + mscale (float, optional): Mscale value for Yarn RoPE. Defaults to 1. + mscale_all_dim (float, optional): Mscale all dim value for Yarn RoPE. Defaults to 0. 
+ """ + + def __init__( + self, + kv_channels: int, + rotary_percent: float = 1.0, + rotary_interleaved: bool = False, + seq_len_interpolation_factor: float = None, + rotary_base: float = 10000.0, + use_cpu_initialization: bool = False, + scaling_factor: float = 1.0, + original_max_position_embeddings: int = 4096, + beta_fast: float = 32.0, + beta_slow: float = 1.0, + mscale: float = 1.0, + mscale_all_dim: float = 0.0, + ): + self.dim = kv_channels + self.rotary_base = rotary_base + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + + device = 'cpu' if use_cpu_initialization else torch.cuda.current_device() + self.inv_freq_extra = 1.0 / ( + self.rotary_base + ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim) + ) + self.inv_freq_inter = 1.0 / ( + self.scaling_factor + * self.rotary_base + ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim) + ) + super().__init__( + kv_channels, + rotary_percent, + rotary_interleaved, + seq_len_interpolation_factor, + rotary_base, + use_cpu_initialization, + ) + + @lru_cache(maxsize=32) + def forward(self, max_seq_len: int, offset: int = 0) -> Tensor: + """Forward pass of Yarn Rotary Embedding. + + Args: + max_seq_len (int): Maximum size of sequence + offset (int, optional): RoPE offset. Defaults to 0. + + Returns: + Tensor: Embeddings after applying Yarn RoPE. + """ + assert ( + not self.rotary_interleaved + ), "Yarn RoPE does not support interleaved rotary embeddings" + + if self.inv_freq_extra.device.type == 'cpu': + # move `inv_freq_extra` to GPU once at the first micro-batch forward pass + self.inv_freq_extra = self.inv_freq_extra.to(device=torch.cuda.current_device()) + + if self.inv_freq_inter.device.type == 'cpu': + # move `inv_freq_inter` to GPU once at the first micro-batch forward pass + self.inv_freq_inter = self.inv_freq_inter.to(device=torch.cuda.current_device()) + + low, high = _yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.dim, + self.rotary_base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - _yarn_linear_ramp_mask(low, high, self.dim // 2).to( + device=self.inv_freq_extra.device, dtype=torch.float32 + ) + inv_freq = self.inv_freq_inter * (1 - inv_freq_mask) + self.inv_freq_extra * inv_freq_mask + + seq = ( + torch.arange( + max_seq_len, device=self.inv_freq_extra.device, dtype=self.inv_freq_extra.dtype + ) + + offset + ) + + freqs = torch.outer(seq, inv_freq) + + _mscale = float( + _yarn_get_mscale(self.scaling_factor, self.mscale) + / _yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + # emb [seq_length, .., dim] + emb = emb[:, None, None, :] + if parallel_state.get_context_parallel_world_size() > 1: + # slice rotary_pos_emb along sequence dimension + # and select the parition of the current CP rank + emb = get_pos_emb_on_this_cp_rank(emb, 0) + return emb, _mscale + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim( + num_rotations: float, dim: int, rotary_base: float = 10000, max_position_embeddings: int = 2048 +) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(rotary_base) + ) + + +# Find dim range bounds based on rotations +def _yarn_find_correction_range( + low_rot: float, + high_rot: 
float, + dim: int, + rotary_base: float = 10000, + max_position_embeddings: int = 2048, +) -> tuple[int, int]: + low = math.floor(_yarn_find_correction_dim(low_rot, dim, rotary_base, max_position_embeddings)) + high = math.ceil(_yarn_find_correction_dim(high_rot, dim, rotary_base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask(min: float, max: float, dim: int) -> Tensor: + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def _yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 diff --git a/megatron/mpu/tests/__init__.py b/megatron/core/models/common/language_module/__init__.py similarity index 100% rename from megatron/mpu/tests/__init__.py rename to megatron/core/models/common/language_module/__init__.py diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py new file mode 100644 index 0000000000..cb26be122f --- /dev/null +++ b/megatron/core/models/common/language_module/language_module.py @@ -0,0 +1,244 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import logging +import os +from typing import Optional, Tuple + +import torch +from torch import Tensor + +from megatron.core import parallel_state, tensor_parallel +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy +from megatron.core.transformer.enums import AttnBackend +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint + + +class LanguageModule(MegatronModule): + """Base language module that has common helper functions used across GPT, BERT etc. + + Args: + config (TransformerConfig): Input transformer config for the model + """ + + def __init__(self, config: TransformerConfig) -> None: + super().__init__(config=config) + self._set_attention_backend() + + # pylint: disable=line-too-long + def _set_attention_backend(self): + """Set attention backend + + Transformer engine works based on optout. By default all three attention backend flags are set to 1. So if the user choses a particular attention backend we set the other two to 0. If the user choses local, we set all 3 TE env variables to 0. + """ + + def check_and_set_env_variable( + env_variable_name: str, expected_value: int, attn_type: AttnBackend + ) -> None: + current_value = os.getenv(env_variable_name) + assert current_value is None or current_value == str( + expected_value + ), f'{env_variable_name} set to {current_value}, but expected {expected_value} for attention backend type {attn_type.name}. unset NVTE_FLASH_ATTN, NVTE_FUSED_ATTN and NVTE_UNFUSED_ATTN. Use the --attention-backend argument if you want to choose between (flash/fused/unfused/auto/local). Default is auto.' 
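# ---------------------------------------------------------------------------
# Editor's aside (illustrative sketch only, not part of the patch): the opt-out
# scheme described in the `_set_attention_backend` docstring, summarised as a
# hypothetical mapping from the chosen backend to the three NVTE flags it keeps
# enabled (1) or disables (0); `auto` keeps all three, `local` disables all
# three. This mirrors the branches that follow.
# NVTE_FLAGS_BY_BACKEND = {
#     'flash':   {'NVTE_FLASH_ATTN': 1, 'NVTE_FUSED_ATTN': 0, 'NVTE_UNFUSED_ATTN': 0},
#     'fused':   {'NVTE_FLASH_ATTN': 0, 'NVTE_FUSED_ATTN': 1, 'NVTE_UNFUSED_ATTN': 0},
#     'unfused': {'NVTE_FLASH_ATTN': 0, 'NVTE_FUSED_ATTN': 0, 'NVTE_UNFUSED_ATTN': 1},
#     'local':   {'NVTE_FLASH_ATTN': 0, 'NVTE_FUSED_ATTN': 0, 'NVTE_UNFUSED_ATTN': 0},
#     'auto':    {'NVTE_FLASH_ATTN': 1, 'NVTE_FUSED_ATTN': 1, 'NVTE_UNFUSED_ATTN': 1},
# }
# ---------------------------------------------------------------------------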
+ os.environ[env_variable_name] = str(expected_value) + + if self.config.attention_backend == AttnBackend.local: + check_and_set_env_variable("NVTE_FLASH_ATTN", 0, AttnBackend.flash) + check_and_set_env_variable("NVTE_FUSED_ATTN", 0, AttnBackend.flash) + check_and_set_env_variable("NVTE_UNFUSED_ATTN", 0, AttnBackend.flash) + elif self.config.attention_backend == AttnBackend.flash: + check_and_set_env_variable("NVTE_FLASH_ATTN", 1, AttnBackend.flash) + check_and_set_env_variable("NVTE_FUSED_ATTN", 0, AttnBackend.flash) + check_and_set_env_variable("NVTE_UNFUSED_ATTN", 0, AttnBackend.flash) + elif self.config.attention_backend == AttnBackend.fused: + check_and_set_env_variable("NVTE_FLASH_ATTN", 0, AttnBackend.fused) + check_and_set_env_variable("NVTE_FUSED_ATTN", 1, AttnBackend.fused) + check_and_set_env_variable("NVTE_UNFUSED_ATTN", 0, AttnBackend.fused) + elif self.config.attention_backend == AttnBackend.unfused: + check_and_set_env_variable("NVTE_FLASH_ATTN", 0, AttnBackend.unfused) + check_and_set_env_variable("NVTE_FUSED_ATTN", 0, AttnBackend.unfused) + check_and_set_env_variable("NVTE_UNFUSED_ATTN", 1, AttnBackend.unfused) + elif self.config.attention_backend == AttnBackend.auto: + check_and_set_env_variable("NVTE_FLASH_ATTN", 1, AttnBackend.auto) + check_and_set_env_variable("NVTE_FUSED_ATTN", 1, AttnBackend.auto) + check_and_set_env_variable("NVTE_UNFUSED_ATTN", 1, AttnBackend.auto) + + def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor: + """Computes the language model loss (Cross entropy across vocabulary) + + Args: + labels (Tensor): The labels of dimension [batch size, seq length] + logits (Tensor): The final logits returned by the output layer of the transformer model + + Returns: + Tensor: Loss tensor of dimensions [batch size, sequence_length] + """ + # [b s] => [s b] + labels = labels.transpose(0, 1).contiguous() + if self.config.cross_entropy_loss_fusion: + loss = fused_vocab_parallel_cross_entropy(logits, labels) + else: + loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels) + + # [s b] => [b, s] + loss = loss.transpose(0, 1).contiguous() + return loss + + def setup_embeddings_and_output_layer(self) -> None: + """Sets up embedding layer in first stage and output layer in last stage. + + This function initalizes word embeddings in the final stage when we are + using pipeline parallelism and sharing word embeddings, and sets up param + attributes on the embedding and output layers. + """ + + # Set `is_embedding_or_output_parameter` attribute. + if self.pre_process: + self.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True + if self.post_process and self.output_layer.weight is not None: + self.output_layer.weight.is_embedding_or_output_parameter = True + + if not self.share_embeddings_and_output_weights: + return + + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + # Zero out wgrad if sharing embeddings between two layers on same + # pipeline stage to make sure grad accumulation into main_grad is + # correct and does not include garbage values (e.g., from torch.empty). 
+ self.shared_embedding_or_output_weight().zero_out_wgrad = True + return + + if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process: + self.shared_embedding_or_output_weight().shared_embedding = True + + if self.post_process and not self.pre_process: + assert not parallel_state.is_pipeline_first_stage() + # set word_embeddings weights to 0 here, then copy first + # stage's weights using all_reduce below. + self.output_layer.weight.data.fill_(0) + self.output_layer.weight.shared = True + self.output_layer.weight.shared_embedding = True + + # Parameters are shared between the word embeddings layers, and the + # heads at the end of the model. In a pipelined setup with more than + # one stage, the initial embedding layer and the head are on different + # workers, so we do the following: + # 1. Create a second copy of word_embeddings on the last stage, with + # initial parameters of 0.0. + # 2. Do an all-reduce between the first and last stage to ensure that + # the two copies of word_embeddings start off with the same + # parameter values. + # 3. In the training loop, before an all-reduce between the grads of + # the two word_embeddings layers to ensure that every applied weight + # update is the same on both stages. + + # Ensure that first and last stages have the same initial parameter + # values. + if torch.distributed.is_initialized(): + if parallel_state.is_rank_in_embedding_group(): + weight = self.shared_embedding_or_output_weight() + weight.data = weight.data.cuda() + torch.distributed.all_reduce( + weight.data, group=parallel_state.get_embedding_group() + ) + + elif not getattr(LanguageModule, "embedding_warning_printed", False): + logging.getLogger(__name__).warning( + "Distributed processes aren't initialized, so the output layer " + "is not initialized with weights from the word embeddings. " + "If you are just manipulating a model this is fine, but " + "this needs to be handled manually. If you are training " + "something is definitely wrong." + ) + LanguageModule.embedding_warning_printed = True + + def shared_embedding_or_output_weight(self) -> Tensor: + """Gets the emedding weight or output logit weights when share embedding and output weights set to True. + + Returns: + Tensor: During pre processing it returns the input embeddings weight while during post processing it returns the final output layers weight + """ + if self.pre_process: + return self.embedding.word_embeddings.weight + elif self.post_process: + return self.output_layer.weight + return None + + def sharded_state_dict( + self, + prefix: str = '', + sharded_offsets: Tuple[Tuple[int, int, int]] = (), + metadata: Optional[dict] = None, + ) -> ShardedStateDict: + """Sharded state dict implementation that handles the output layer weights tying. + + Args: + prefix (str): Module name prefix. + sharded_offsets (tuple): PP related offsets, expected to be empty at this module level. + metadata (Optional[Dict]): metadata controlling sharded state dict creation. 
+ + Returns: + ShardedStateDict: sharded state dict for the LanguageModel + """ + assert not sharded_offsets, "Unexpected sharded offsets" + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + + first_stage_word_emb_key = f'{prefix}embedding.word_embeddings.weight' + output_layer_weight_key = f'{prefix}output_layer.weight' + output_layer_bias_key = f'{prefix}output_layer.bias' + + if self.share_embeddings_and_output_weights: + self.tie_embeddings_and_output_weights_state_dict( + sharded_state_dict, output_layer_weight_key, first_stage_word_emb_key + ) + elif self.post_process: + # Make sure the output layer follows the embeddings padding logic + sharded_state_dict[output_layer_weight_key].allow_shape_mismatch = True + + # Regardless of sharing the output weights with embeddings, we must handle the bias padding + if self.post_process and output_layer_bias_key in sharded_state_dict: + sharded_state_dict[output_layer_bias_key].allow_shape_mismatch = True + + return sharded_state_dict + + def tie_embeddings_and_output_weights_state_dict( + self, + sharded_state_dict: ShardedStateDict, + output_layer_weight_key: str, + first_stage_word_emb_key: str, + ) -> None: + """Ties the embedding and output weights in a given sharded state dict. + + Args: + sharded_state_dict (ShardedStateDict): state dict with the weight to tie + output_layer_weight_key (str): key of the output layer weight in the state dict. + This entry will be replaced with a tied version + first_stage_word_emb_key (str): this must be the same as the + ShardedTensor.key of the first stage word embeddings. + + Returns: None, acts in-place + """ + if not self.post_process: + # No output layer + assert output_layer_weight_key not in sharded_state_dict, sharded_state_dict.keys() + return + + if self.pre_process: + # Output layer is equivalent to the embedding already + return + + # Replace the default output layer with a one sharing the weights with the embedding + del sharded_state_dict[output_layer_weight_key] + tensor = self.shared_embedding_or_output_weight() + last_stage_word_emb_replica_id = ( + 1, # copy of first stage embedding + 0, + parallel_state.get_data_parallel_rank(with_context_parallel=True), + ) + + sharded_state_dict[output_layer_weight_key] = make_tp_sharded_tensor_for_checkpoint( + tensor=tensor, + key=first_stage_word_emb_key, + replica_id=last_stage_word_emb_replica_id, + allow_shape_mismatch=True, + ) diff --git a/megatron/core/models/common/vision_module/__init__.py b/megatron/core/models/common/vision_module/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/megatron/core/models/common/vision_module/vision_module.py b/megatron/core/models/common/vision_module/vision_module.py new file mode 100644 index 0000000000..5dc51873a4 --- /dev/null +++ b/megatron/core/models/common/vision_module/vision_module.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +"""Megatron Vision Module.""" + +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig + + +# Note: This is only a stub at the moment. This will be expanded in follow-up changes. +class VisionModule(MegatronModule): + """Base vision module that has common helper functions used across CLIP, ViT, etc. 
+ + Args: + config (TransformerConfig): Input transformer config for the model + """ + + def __init__(self, config: TransformerConfig) -> None: + super().__init__(config=config) diff --git a/megatron/core/models/gpt/__init__.py b/megatron/core/models/gpt/__init__.py index 2d5eb8674f..8bbecfcb09 100644 --- a/megatron/core/models/gpt/__init__.py +++ b/megatron/core/models/gpt/__init__.py @@ -1 +1,2 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from .gpt_model import GPTModel diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index 9d3f6dcd4d..38e530c6da 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -1,116 +1,384 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +import warnings +from typing import Optional + from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add -from megatron.core.fusions.fused_layer_norm import FusedLayerNorm +from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules -from megatron.core.transformer.custom_layers.transformer_engine import ( - TEDotProductAttention, - TELayerNormColumnParallelLinear, - TERowParallelLinear, -) from megatron.core.transformer.dot_product_attention import DotProductAttention from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.multi_latent_attention import ( + MLASelfAttention, + MLASelfAttentionSubmodules, +) from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.switch_mlp import SwitchMLP -from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules - -# Use this spec to use lower level Transformer Engine modules (required for fp8 training) -gpt_layer_with_transformer_engine_spec = ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - linear_qkv=TELayerNormColumnParallelLinear, - dot_product_attention=TEDotProductAttention, - linear_proj=TERowParallelLinear, - ), - ), - self_attn_bda=get_bias_dropout_add, - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear, - ), - ), - mlp_bda=get_bias_dropout_add, - ), +from megatron.core.transformer.transformer_block import ( + TransformerBlockSubmodules, + get_num_layers_to_build, +) +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import ( + TransformerLayer, + TransformerLayerSubmodules, + get_transformer_layer_offset, ) +from megatron.core.utils import is_te_min_version + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TELinear, + TENorm, + TERowParallelLinear, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +try: + import apex # pylint: disable=unused-import + + from megatron.core.fusions.fused_layer_norm import FusedLayerNorm -# Use this spec for an implementation 
using only modules in megatron core -gpt_layer_local_spec = ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - input_layernorm=FusedLayerNorm, - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - linear_qkv=ColumnParallelLinear, - dot_product_attention=DotProductAttention, - linear_proj=RowParallelLinear, + HAVE_APEX = True + LNImpl = FusedLayerNorm +except ImportError: + from megatron.core.transformer.torch_norm import WrappedTorchNorm + + warnings.warn('Apex is not installed. Falling back to Torch Norm') + LNImpl = WrappedTorchNorm + + +def get_gpt_layer_with_transformer_engine_spec( + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + qk_layernorm: Optional[bool] = False, + multi_latent_attention: Optional[bool] = False, + fp8: Optional[str] = None, # pylint: disable=unused-arguments + moe_use_legacy_grouped_gemm: Optional[bool] = False, +) -> ModuleSpec: + """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). + + + Args: + num_experts (int, optional): Number of experts. Defaults to None. + moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False. + qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False. + fp8 (str, optional): Deprecated. For temporary Nemo compatibility. + moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP. + Defaults to False. + + Returns: + ModuleSpec: Module specification with TE modules + """ + if fp8 is not None: + warnings.warn( + 'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated' + ' and will be removed soon. Please update your code accordingly.' + ) + + mlp = get_mlp_module_spec( + use_te=True, + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, + ) + + if multi_latent_attention: + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=TENorm, + self_attention=ModuleSpec( + module=MLASelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=MLASelfAttentionSubmodules( + linear_q_proj=TEColumnParallelLinear, + linear_q_down_proj=TELinear, + linear_q_up_proj=( + TELayerNormColumnParallelLinear + if qk_layernorm + else TEColumnParallelLinear + ), + linear_kv_down_proj=TELinear, + linear_kv_up_proj=( + TELayerNormColumnParallelLinear + if qk_layernorm + else TEColumnParallelLinear + ), + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + kv_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=TENorm if num_experts else IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, ), - ), - self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=FusedLayerNorm, - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear, + ) + else: + + # TENorm significantly harms convergence when used + # for QKLayerNorm if TE Version < 1.9; + # we instead use the Apex implementation. 
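The `is_te_min_version("1.9.0")` gate applied just below for `qk_norm` is imported from `megatron.core.utils` earlier in this file. A minimal standalone approximation of such a gate, using the `packaging` library, looks like the following; the helper name and the hard-coded version strings here are assumptions for illustration.

# Minimal sketch of a ">= minimum version" gate in the spirit of is_te_min_version
# (the real helper lives in megatron.core.utils; this re-implementation is an assumption).
from packaging import version


def is_min_version(installed: str, minimum: str) -> bool:
    """Return True if `installed` is at least `minimum` (PEP 440 comparison)."""
    return version.parse(installed) >= version.parse(minimum)


# With TE older than 1.9, QK layernorm falls back to the Apex FusedLayerNorm path.
print(is_min_version("1.8.0", "1.9.0"))   # False -> use FusedLayerNorm for q/k layernorm
print(is_min_version("1.10.0", "1.9.0"))  # True  -> TENorm is safe for q/k layernorm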
+ qk_norm = TENorm if is_te_min_version("1.9.0") else FusedLayerNorm + + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=qk_norm if qk_layernorm else IdentityOp, + k_layernorm=qk_norm if qk_layernorm else IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=TENorm if num_experts else IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, ), - ), - mlp_bda=get_bias_dropout_add, - ), -) + ) + + +def get_gpt_layer_local_spec( + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + qk_layernorm: Optional[bool] = False, + multi_latent_attention: Optional[bool] = False, + fp8: Optional[str] = None, # pylint: disable=unused-arguments + moe_use_legacy_grouped_gemm: Optional[bool] = False, +) -> ModuleSpec: + """Use this spec for an implementation using only modules in Megatron-Core. + + + Args: + num_experts (int, optional): Number of experts. Defaults to None. + moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False. + qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False. + fp8 (str, optional): Deprecated. For temporary Nemo compatibility. + moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP. + Defaults to False. + + Returns: + ModuleSpec: Module specification with Megatron-Core modules + """ + if fp8 is not None: + warnings.warn( + 'The fp8 argument in "get_gpt_layer_local_spec" has been deprecated' + ' and will be removed soon. Please update your code accordingly.' 
+ ) + + mlp = get_mlp_module_spec( + use_te=False, + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, + ) -# Use this spec to use lower level Transformer Engine modules and SwitchMLP based MoE -gpt_layer_with_transformer_engine_spec_moe = ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - linear_qkv=TELayerNormColumnParallelLinear, - dot_product_attention=TEDotProductAttention, - linear_proj=TERowParallelLinear, + if multi_latent_attention: + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=LNImpl, + self_attention=ModuleSpec( + module=MLASelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=MLASelfAttentionSubmodules( + linear_q_proj=ColumnParallelLinear, + linear_q_down_proj=ColumnParallelLinear, + linear_q_up_proj=ColumnParallelLinear, + linear_kv_down_proj=ColumnParallelLinear, + linear_kv_up_proj=ColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=RowParallelLinear, + q_layernorm=LNImpl if qk_layernorm else IdentityOp, + kv_layernorm=LNImpl if qk_layernorm else IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=LNImpl, + mlp=mlp, + mlp_bda=get_bias_dropout_add, ), - ), - self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=FusedLayerNorm, - mlp=ModuleSpec( - module=SwitchMLP, # MOE - submodules=MLPSubmodules( - linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear, + ) + else: + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=LNImpl, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=ColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=RowParallelLinear, + q_layernorm=LNImpl if qk_layernorm else IdentityOp, + k_layernorm=LNImpl if qk_layernorm else IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=LNImpl, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + sharded_state_dict_keys_map={ + 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', + 'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_', + }, ), - ), - mlp_bda=get_bias_dropout_add, - ), -) + ) -# Use this spec for an implementation using only modules in megatron core for MoE models -gpt_layer_local_spec_moe = ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - input_layernorm=FusedLayerNorm, - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - linear_qkv=ColumnParallelLinear, - dot_product_attention=DotProductAttention, - linear_proj=RowParallelLinear, - ), - ), - self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=FusedLayerNorm, - mlp=ModuleSpec( - module=SwitchMLP, # MOE + +def _get_mlp_module_spec( + use_te: Optional[bool] = True, + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + fp8: Optional[str] = None, # pylint: disable=unused-arguments + moe_use_legacy_grouped_gemm: Optional[bool] = False, +): + warnings.warn( + """This private function is on a deprecation track. 
Please switch to `get_mlp_module_spec` + since it will be removed in a future release.""" + ) + + return get_mlp_module_spec( + use_te=use_te, + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + fp8=fp8, + moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, + ) + + +def get_mlp_module_spec( + use_te: Optional[bool] = True, + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + fp8: Optional[str] = None, # pylint: disable=unused-arguments + moe_use_legacy_grouped_gemm: Optional[bool] = False, +) -> ModuleSpec: + """Helper function to get module spec for MLP/MoE""" + if fp8 is not None: + warnings.warn( + 'The fp8 argument in "_get_mlp_module_spec" has been deprecated' + ' and will be removed soon. Please update your code accordingly.' + ) + + if num_experts is None: + # Dense MLP w/ or w/o TE modules. + return ModuleSpec( + module=MLP, submodules=MLPSubmodules( - linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear, + linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear, + linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, ), - ), - mlp_bda=get_bias_dropout_add, - ), -) + ) + else: + # Mixture of experts with modules in megatron core. + return get_moe_module_spec( + use_te=use_te, + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, + ) + + +def get_gpt_decoder_block_spec( + config: TransformerConfig, use_transformer_engine: bool +) -> TransformerBlockSubmodules: + """GPT block spec.""" + if use_transformer_engine: + layer_norm_impl = TENorm + else: + layer_norm_impl = LNImpl + + # Layer specs. + dense_layer_spec = ( + get_gpt_layer_with_transformer_engine_spec( + num_experts=None, + moe_grouped_gemm=False, + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + ) + if use_transformer_engine + else get_gpt_layer_local_spec( + num_experts=None, + moe_grouped_gemm=False, + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + ) + ) + moe_layer_spec = ( + get_gpt_layer_with_transformer_engine_spec( + num_experts=config.num_moe_experts, + moe_grouped_gemm=config.moe_grouped_gemm, + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + ) + if use_transformer_engine + else get_gpt_layer_local_spec( + num_experts=config.num_moe_experts, + moe_grouped_gemm=config.moe_grouped_gemm, + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + ) + ) + + # Parse config.moe_layer_freq to determine the pattern of expert/dense layers. + # 0 stands for dense layers, 1 stands for expert layers. + # For integer N: Creates a pattern with one expert layer every N layers. + # For string pattern: Evaluates the str directly (e.g. "[1,0,1]" for alternating expert/dense). 
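The comment above describes how moe_layer_freq expands into a per-layer expert/dense pattern, and the parsing just below accepts either an integer or an explicit list. A standalone expansion with hypothetical values makes the two forms concrete.

# Standalone illustration of the moe_layer_freq expansion described above
# (the values passed at the bottom are hypothetical examples, not defaults).
def expand_moe_layer_pattern(moe_layer_freq, num_layers):
    if isinstance(moe_layer_freq, int):
        # One expert layer every N layers, starting with layer 0.
        return [1 if (i % moe_layer_freq == 0) else 0 for i in range(num_layers)]
    if isinstance(moe_layer_freq, list):
        assert len(moe_layer_freq) == num_layers
        return moe_layer_freq
    raise ValueError(f"Invalid moe_layer_freq: {moe_layer_freq!r}")


print(expand_moe_layer_pattern(2, 8))          # [1, 0, 1, 0, 1, 0, 1, 0]
print(expand_moe_layer_pattern([1, 0, 1], 3))  # explicit per-layer pattern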
+ if isinstance(config.moe_layer_freq, int): + moe_layer_pattern = [ + 1 if (i % config.moe_layer_freq == 0) else 0 for i in range(config.num_layers) + ] + elif isinstance(config.moe_layer_freq, list): + moe_layer_pattern = config.moe_layer_freq + assert len(moe_layer_pattern) == config.num_layers, ( + f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, " + f"expected {config.num_layers}, " + f"current moe layer pattern: {config.moe_layer_freq}" + ) + else: + raise ValueError( + f"Invalid moe_layer_freq: {type(config.moe_layer_freq)}, {config.moe_layer_freq}" + ) + + # Create the layer specs for the model. + layer_specs = [] + for layer_number in range(config.num_layers): + if moe_layer_pattern[layer_number] == 1: + layer_specs.append(moe_layer_spec) + elif moe_layer_pattern[layer_number] == 0: + layer_specs.append(dense_layer_spec) + else: + raise ValueError(f"Invalid layer pattern: {moe_layer_pattern}") + + # Slice the layer specs to only include the layers that are built in this pipeline stage. + # Note: MCore layer_number starts at 1 + offset = get_transformer_layer_offset(config) + num_layers_to_build = get_num_layers_to_build(config) + layer_specs = layer_specs[offset : offset + num_layers_to_build] + + # Block spec. + block_spec = TransformerBlockSubmodules(layer_specs=layer_specs, layer_norm=layer_norm_impl) + + return block_spec diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index d5a9f7de48..beab46a289 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -1,38 +1,64 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -import logging -from typing import Literal, Optional, Union +from collections import OrderedDict +from typing import Dict, Literal, Optional import torch from torch import Tensor -from megatron.core import parallel_state, tensor_parallel +from megatron.core import InferenceParams, tensor_parallel +from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk +from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding -from megatron.core.models.common.embeddings.language_module.language_module import LanguageModule from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding -from megatron.core.transformer.enums import AttnMaskType, ModelType +from megatron.core.models.common.language_module.language_module import LanguageModule +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.enums import ModelType from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint class GPTModel(LanguageModule): """GPT Transformer language model. Args: - config (TransformerConfig): Transformer config - transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers - vocab_size (int): Vocabulary size - max_sequence_length (int): maximum size of sequence. This is used for positional embedding - pre_process (bool, optional): Include embedding layer (used with pipeline parallelism). Defaults to True. - post_process (bool, optional): Include an output layer (used with pipeline parallelism). Defaults to True. 
- fp16_lm_cross_entropy (bool, optional): Defaults to False. - parallel_output (bool, optional): Do not gather the outputs, keep them split across tensor parallel ranks. Defaults to True. - share_embeddings_and_output_weights (bool, optional): When True, input embeddings and output logit weights are shared. Defaults to False. - position_embedding_type (Literal[learned_absolute,rope], optional): Position embedding type.. Defaults to 'learned_absolute'. - rotary_percent (float, optional): Percent of rotary dimension to use for rotary position embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 1.0. - seq_len_interpolation_factor (Optional[float], optional): scale of linearly interpolating RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None. + config (TransformerConfig): + Transformer config + transformer_layer_spec (ModuleSpec): + Specifies module to use for transformer layers + vocab_size (int): + Vocabulary size + max_sequence_length (int): + maximum size of sequence. This is used for positional embedding + pre_process (bool, optional): + Include embedding layer (used with pipeline parallelism). Defaults to True. + post_process (bool, optional): + Include an output layer (used with pipeline parallelism). Defaults to True. + fp16_lm_cross_entropy (bool, optional): + Defaults to False. + parallel_output (bool, optional): + Do not gather the outputs, keep them split across tensor + parallel ranks. Defaults to True. + share_embeddings_and_output_weights (bool, optional): + When True, input embeddings and output logit weights are shared. Defaults to False. + position_embedding_type (Literal[learned_absolute,rope], optional): + Position embedding type.. Defaults to 'learned_absolute'. + rotary_percent (float, optional): + Percent of rotary dimension to use for rotary position embeddings. + Ignored unless position_embedding_type is 'rope'. Defaults to 1.0. + rotary_base (int, optional): + Base period for rotary position embeddings. Ignored unless + position_embedding_type is 'rope'. + Defaults to 10000. + rope_scaling (bool, optional): Toggle RoPE scaling. + rope_scaling_factor (float): RoPE scaling factor. Default 8. + scatter_embedding_sequence_parallel (bool, optional): + Whether embeddings should be scattered across sequence parallel + region or not. Defaults to True. + seq_len_interpolation_factor (Optional[float], optional): + scale of linearly interpolating RoPE for longer sequences. + The value must be a float larger than 1.0. Defaults to None. """ def __init__( @@ -46,13 +72,19 @@ def __init__( fp16_lm_cross_entropy: bool = False, parallel_output: bool = True, share_embeddings_and_output_weights: bool = False, - position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute', + position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute', rotary_percent: float = 1.0, + rotary_base: int = 10000, + rope_scaling: bool = False, + rope_scaling_factor: float = 8.0, + scatter_embedding_sequence_parallel: bool = True, seq_len_interpolation_factor: Optional[float] = None, ) -> None: super().__init__(config=config) - self.config: TransformerConfig = config + if has_config_logger_enabled(config): + log_config_to_disk(config, locals(), prefix=type(self).__name__) + self.transformer_layer_spec: ModuleSpec = transformer_layer_spec self.vocab_size = vocab_size self.max_sequence_length = max_sequence_length @@ -67,30 +99,61 @@ def __init__( # TODO: remove this dependency ? 
self.model_type = ModelType.encoder_or_decoder + # These 4 attributes are needed for TensorRT-LLM export. + self.max_position_embeddings = max_sequence_length + self.rotary_percent = rotary_percent + self.rotary_base = rotary_base + self.rotary_scaling = rope_scaling + if self.pre_process: self.embedding = LanguageModelEmbedding( config=self.config, vocab_size=self.vocab_size, max_sequence_length=self.max_sequence_length, position_embedding_type=position_embedding_type, + scatter_to_sequence_parallel=scatter_embedding_sequence_parallel, ) - if self.position_embedding_type == 'rope': + if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention: self.rotary_pos_emb = RotaryEmbedding( - self.config.kv_channels, rotary_percent, seq_len_interpolation_factor + kv_channels=self.config.kv_channels, + rotary_percent=rotary_percent, + rotary_interleaved=self.config.rotary_interleaved, + seq_len_interpolation_factor=seq_len_interpolation_factor, + rotary_base=rotary_base, + rope_scaling=rope_scaling, + rope_scaling_factor=rope_scaling_factor, + use_cpu_initialization=self.config.use_cpu_initialization, ) + # Cache for RoPE tensors which do not change between iterations. + self.rotary_pos_emb_cache = {} + # Transformer. self.decoder = TransformerBlock( config=self.config, - transformer_layer_spec=self.transformer_layer_spec, - self_attn_mask_type=AttnMaskType.causal, + spec=transformer_layer_spec, pre_process=self.pre_process, post_process=self.post_process, ) # Output if post_process: + if self.config.defer_embedding_wgrad_compute: + # The embedding activation buffer preserves a reference to the input activations + # of the final embedding projection layer GEMM. It will hold the activations for + # all the micro-batches of a global batch for the last pipeline stage. Once we are + # done with all the back props for all the microbatches for the last pipeline stage, + # it will be in the pipeline flush stage. During this pipeline flush we use the + # input activations stored in embedding activation buffer and gradient outputs + # stored in gradient buffer to calculate the weight gradients for the embedding + # final linear layer. + self.embedding_activation_buffer = [] + self.grad_output_buffer = [] + else: + self.embedding_activation_buffer = None + self.grad_output_buffer = None + self.output_layer = tensor_parallel.ColumnParallelLinear( config.hidden_size, self.vocab_size, @@ -101,10 +164,33 @@ def __init__( gather_output=not self.parallel_output, skip_weight_param_allocation=self.pre_process and self.share_embeddings_and_output_weights, + embedding_activation_buffer=self.embedding_activation_buffer, + grad_output_buffer=self.grad_output_buffer, + ) + + if self.pre_process or self.post_process: + self.setup_embeddings_and_output_layer() + + if has_config_logger_enabled(self.config): + log_config_to_disk( + self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt' ) - if self.share_embeddings_and_output_weights and (self.pre_process or self.post_process): - self.initialize_last_stage_with_word_embeddings() + def set_input_tensor(self, input_tensor: Tensor) -> None: + """Sets input tensor to the model. + + See megatron.model.transformer.set_input_tensor() + + Args: + input_tensor (Tensor): Sets the input tensor for the model. 
+ """ + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + + assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert' + self.decoder.set_input_tensor(input_tensor[0]) def forward( self, @@ -113,13 +199,20 @@ def forward( attention_mask: Tensor, decoder_input: Tensor = None, labels: Tensor = None, - inference_params=None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + runtime_gather_output: Optional[bool] = None, ) -> Tensor: """Forward function of the GPT Model This function passes the input tensors through the embedding layer, and then the decoeder and finally into the post processing layer (optional). It either returns the Loss values if labels are given or the final hidden units + + Args: + runtime_gather_output (bool): Gather output at runtime. Default None means + `parallel_output` arg in the constructor will be used. """ # If decoder_input is provided (not None), then input_ids and position_ids are ignored. # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. @@ -136,11 +229,36 @@ def forward( # Rotary positional embeddings (embedding is None for PP intermediate devices) rotary_pos_emb = None - if self.position_embedding_type == 'rope': - rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( - inference_params, self.decoder, decoder_input, self.config + rotary_pos_cos = None + rotary_pos_sin = None + if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention: + if not self.training and self.config.flash_decode and inference_params: + # Flash decoding uses precomputed cos and sin for RoPE + rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault( + inference_params.max_sequence_length, + self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length), + ) + else: + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_params, self.decoder, decoder_input, self.config, packed_seq_params + ) + rotary_pos_emb = self.rotary_pos_emb( + rotary_seq_len, + packed_seq=packed_seq_params is not None + and packed_seq_params.qkv_format == 'thd', + ) + if ( + (self.config.enable_cuda_graph or self.config.flash_decode) + and rotary_pos_cos is not None + and inference_params + ): + sequence_len_offset = torch.tensor( + [inference_params.sequence_len_offset] * inference_params.current_batch_size, + dtype=torch.int32, + device=rotary_pos_cos.device, # Co-locate this with the rotary tensors ) - rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + else: + sequence_len_offset = None # Run decoder. 
hidden_states = self.decoder( @@ -148,6 +266,11 @@ def forward( attention_mask=attention_mask, inference_params=inference_params, rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + **(extra_block_kwargs or {}), ) if not self.post_process: @@ -157,7 +280,21 @@ def forward( output_weight = None if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - logits, _ = self.output_layer(hidden_states, weight=output_weight) + logits, _ = self.output_layer( + hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output + ) + + if has_config_logger_enabled(self.config): + payload = OrderedDict( + { + 'input_ids': input_ids, + 'position_ids': position_ids, + 'attention_mask': attention_mask, + 'decoder_input': decoder_input, + 'logits': logits, + } + ) + log_config_to_disk(self.config, payload, prefix='input_and_logits') if labels is None: # [s b h] => [b s h] @@ -167,66 +304,28 @@ def forward( return loss - def shared_embedding_or_output_weight(self) -> Tensor: - """Function to share the input embeddings and output logit weights. + def sharded_state_dict( + self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None + ) -> ShardedStateDict: + """Sharded state dict implementation for GPTModel backward-compatibility + (removing extra state). + + Args: + prefix (str): Module name prefix. + sharded_offsets (tuple): PP related offsets, expected to be empty at this module level. + metadata (Optional[Dict]): metadata controlling sharded state dict creation. Returns: - Tensor: During pre processing it returns the input embeddings weight while during post processing it returns the final output layers weight + ShardedStateDict: sharded state dict for the GPTModel """ - if self.pre_process: - return self.embedding.word_embeddings.weight - elif self.post_process: - return self.output_layer.weight - return None - - def sharded_state_dict(self, prefix: str = '') -> dict: - sharded_state_dict = {} - - if self.pre_process: - embedding_prefix = f'{prefix}embedding.' - embedding_sharded_state_dict = self.embedding.sharded_state_dict( - prefix=embedding_prefix - ) - sharded_state_dict.update(embedding_sharded_state_dict) - - decoder_prefix = f'{prefix}decoder.' - decoder_sharded_state_dict = self.decoder.sharded_state_dict(prefix=decoder_prefix) - sharded_state_dict.update(decoder_sharded_state_dict) - - if self.post_process: - output_layer_prefix = f'{prefix}output_layer.' 
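Megatron keeps decoder activations sequence-first, so the comments above track a [s, b, h] -> [b, s, h] transpose for returned logits and a [b, s] per-token loss. The toy shape walk-through below uses arbitrary sizes and plain torch cross entropy as a stand-in for the vocab-parallel loss; it is an illustration of the layout convention only.

# Toy shape walk-through of the sequence-first convention used above (sizes are arbitrary,
# torch.nn.functional.cross_entropy stands in for the vocab-parallel loss).
import torch

s, b, h, vocab = 6, 2, 8, 32
hidden_states = torch.randn(s, b, h)                 # decoder output: [s, b, h]
proj = torch.nn.Linear(h, vocab)
logits = proj(hidden_states)                          # output layer result: [s, b, vocab]

returned_logits = logits.transpose(0, 1).contiguous()  # [b, s, vocab] handed back to callers
labels = torch.randint(0, vocab, (b, s))                # labels arrive batch-first: [b, s]
labels_sb = labels.transpose(0, 1).contiguous()         # loss runs sequence-first: [s, b]
per_token_loss = torch.nn.functional.cross_entropy(
    logits.reshape(s * b, vocab), labels_sb.reshape(s * b), reduction="none"
).view(s, b)
loss = per_token_loss.transpose(0, 1).contiguous()      # [b, s], as compute_language_model_loss returns

print(returned_logits.shape, loss.shape)  # torch.Size([2, 6, 32]) torch.Size([2, 6])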
- output_layer_key = f'{output_layer_prefix}weight' - if self.share_embeddings_and_output_weights: - if not self.pre_process: - # when sharing embeddings with last stage, we need to use the weights from the first stage - # on pipeline first rank, word embeddings are saved to {prefix}embedding.word_embeddings.weight - tensor = self.shared_embedding_or_output_weight() - first_stage_word_emb_key = f'{prefix}embedding.word_embeddings.weight' - last_stage_word_emb_replica_id = ( - 1, # copy of first stage embedding - 0, - parallel_state.get_data_parallel_rank(), - ) - - sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint( - tensor=tensor, - key=first_stage_word_emb_key, - replica_id=last_stage_word_emb_replica_id, - allow_shape_mismatch=True, - ) - - sharded_state_dict[output_layer_key] = sharded_output_layer_tensor - - else: - output_layer_state_dict = self.output_layer.state_dict( - prefix=output_layer_prefix, keep_vars=True - ) - output_layer_tensor = output_layer_state_dict[output_layer_key] - # independent output layer - sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint( - tensor=output_layer_tensor, key=output_layer_key, allow_shape_mismatch=True, - ) - - sharded_state_dict[output_layer_key] = sharded_output_layer_tensor + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + output_layer_extra_state_key = f'{prefix}output_layer._extra_state' + + # Old GPT checkpoints only stored the output layer weight key. So we remove the + # _extra_state key but check that it doesn't contain any data anyway + output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None) + assert not ( + output_extra_state and output_extra_state.data + ), f'Expected output layer extra state to be empty, got: {output_extra_state}' return sharded_state_dict diff --git a/megatron/core/models/gpt/moe_module_specs.py b/megatron/core/models/gpt/moe_module_specs.py new file mode 100755 index 0000000000..513eeddc7e --- /dev/null +++ b/megatron/core/models/gpt/moe_module_specs.py @@ -0,0 +1,81 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
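GPTModel.sharded_state_dict above keeps old checkpoints loadable by popping the output_layer._extra_state entry and asserting it carried no data. The tiny standalone sketch below shows the same pattern on plain dicts; the key name is taken from the code above, the placeholder values are assumptions.

# Standalone sketch of the backward-compatibility trick above: a present-but-empty
# "_extra_state" entry is dropped so old checkpoints without it still match.
def drop_empty_extra_state(state_dict: dict, key: str = "output_layer._extra_state") -> dict:
    extra = state_dict.pop(key, None)
    assert not extra, f"Expected {key} to be empty, got: {extra!r}"
    return state_dict


ckpt = {"output_layer.weight": "sharded tensor placeholder", "output_layer._extra_state": None}
print(drop_empty_extra_state(ckpt).keys())  # only the weight key survives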
+ +import warnings +from typing import Optional + +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.mlp import MLPSubmodules +from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP +from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules +from megatron.core.transformer.moe.shared_experts import SharedExpertMLP +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.utils import get_te_version, is_te_min_version + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelGroupedLinear, + TEColumnParallelLinear, + TERowParallelGroupedLinear, + TERowParallelLinear, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + + +def get_moe_module_spec( + use_te: Optional[bool] = True, + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + moe_use_legacy_grouped_gemm: Optional[bool] = False, +) -> ModuleSpec: + """Helper function to get module spec for MoE""" + assert num_experts is not None + + mlp = MLPSubmodules( + linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear, + linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, + ) + + # experts spec + if moe_grouped_gemm: + ## use GroupedMLP + if use_te and TEColumnParallelGroupedLinear is not None and not moe_use_legacy_grouped_gemm: + ## use TEGroupedLinear + expert_module = TEGroupedMLP + expert_submodule = MLPSubmodules( + linear_fc1=TEColumnParallelGroupedLinear, linear_fc2=TERowParallelGroupedLinear + ) + else: + ## use legacy GroupedMLP + expert_module = GroupedMLP + expert_submodule = None + warnings.warn( + 'The legacy GroupedMLP will be deprecated in Megatron-Core v0.12.0. ' + 'Please update the TransformerEngine to version>=1.7.0 and use TEGroupedMLP.' + ) + else: + ## use SequentialMLP + expert_module = SequentialMLP + if use_te and not is_te_min_version("1.7.0.dev0"): + warnings.warn( + "Only transformer-engine>=1.7.0 supports MoE experts, " + f"but your version is {get_te_version()}. Use local linear implementation instead." + ) + expert_submodule = MLPSubmodules( + linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear + ) + else: + expert_submodule = mlp + + experts = ModuleSpec(module=expert_module, submodules=expert_submodule) + + # shared experts spec + shared_experts = ModuleSpec(module=SharedExpertMLP, params={"gate": False}, submodules=mlp) + + # MoE module spec + moe_module_spec = ModuleSpec( + module=MoELayer, submodules=MoESubmodules(experts=experts, shared_experts=shared_experts) + ) + return moe_module_spec diff --git a/megatron/core/models/huggingface/__init__.py b/megatron/core/models/huggingface/__init__.py new file mode 100644 index 0000000000..d5ad39d593 --- /dev/null +++ b/megatron/core/models/huggingface/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +from .module import HuggingFaceModule, build_hf_model diff --git a/megatron/core/models/huggingface/clip_model.py b/megatron/core/models/huggingface/clip_model.py new file mode 100644 index 0000000000..f1522e9653 --- /dev/null +++ b/megatron/core/models/huggingface/clip_model.py @@ -0,0 +1,22 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
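get_moe_module_spec above chooses the expert implementation from three inputs (grouped GEMM requested, TE grouped linear available, legacy GroupedMLP forced). The decision is easier to read as the small standalone function below, which collapses the `use_te and TEColumnParallelGroupedLinear is not None` check into one boolean and returns plain strings instead of ModuleSpec objects; it is a summary, not the real function.

# Plain-string summary of the expert-module selection in get_moe_module_spec above
# (illustration only; the real function returns ModuleSpec objects, not strings).
def pick_expert_module(moe_grouped_gemm: bool, te_grouped_linear_available: bool,
                       use_legacy_grouped_gemm: bool) -> str:
    if moe_grouped_gemm:
        if te_grouped_linear_available and not use_legacy_grouped_gemm:
            return "TEGroupedMLP"      # preferred grouped-GEMM path
        return "GroupedMLP"            # legacy path, deprecated in favour of TEGroupedMLP
    return "SequentialMLP"             # one MLP per local expert, no grouped GEMM


print(pick_expert_module(True, True, False))   # TEGroupedMLP
print(pick_expert_module(True, False, False))  # GroupedMLP
print(pick_expert_module(False, True, False))  # SequentialMLP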
+ +from transformers import AutoModel + +from megatron.core.models.huggingface import HuggingFaceModule + + +class ClipHuggingFaceModel(HuggingFaceModule): + """ + Wrapper for CLIP HuggingFace models + """ + + def __init__(self, config): + super().__init__(config) + self.model = AutoModel.from_pretrained(config.huggingface_model_name_or_path) + + def forward(self, *args, **kwargs): + """Forward function""" + x = self.model(*args, **kwargs) + x = x['last_hidden_state'] + + return x diff --git a/megatron/core/models/huggingface/module.py b/megatron/core/models/huggingface/module.py new file mode 100644 index 0000000000..823925ed02 --- /dev/null +++ b/megatron/core/models/huggingface/module.py @@ -0,0 +1,50 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +from transformers import AutoConfig, AutoModel + +from megatron.core.transformer.module import MegatronModule + + +class HuggingFaceModule(MegatronModule): + """ + Basic module for huggingface + """ + + def __init__(self, config): + super().__init__(config=config) + + def set_input_tensor(self, input_tensor): + """Dummy function for set_input_tensor""" + self.input_tensor = input_tensor + + +class AutoHuggingFaceModel(HuggingFaceModule): + """ + Wrapper for HuggingFace AutoModel + """ + + def __init__(self, config): + super().__init__(config) + self.model = AutoModel.from_pretrained(config.huggingface_model_name_or_path) + + def forward(self, *args, **kwargs): + """Forward function""" + return self.model(*args, **kwargs) + + +def build_hf_model(config): + """Builds huggingface wrapper model given config""" + hf_config = AutoConfig.from_pretrained(config.huggingface_model_name_or_path) + + if "qwen" in hf_config.model_type: + from megatron.core.models.huggingface.qwen_model import QwenHuggingFaceModel + + model = QwenHuggingFaceModel(config) + elif "vit" in hf_config.model_type: + from megatron.core.models.huggingface.clip_model import ClipHuggingFaceModel + + model = ClipHuggingFaceModel(config) + else: + raise NotImplementedError(f"Huggingface model type {hf_config.model_type} is not supported") + + return model diff --git a/megatron/core/models/huggingface/qwen_model.py b/megatron/core/models/huggingface/qwen_model.py new file mode 100644 index 0000000000..216eeaa87f --- /dev/null +++ b/megatron/core/models/huggingface/qwen_model.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
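build_hf_model above dispatches on the HuggingFace model_type string ("qwen*" wraps a causal LM, "*vit*" wraps a vision encoder). A hedged usage sketch follows; the checkpoint name and the minimal config object are assumptions, and the actual (heavy) model load is left commented out.

# Hypothetical usage of build_hf_model; the checkpoint name and the SimpleNamespace config
# are assumptions for illustration, and the download-triggering call is commented out.
from types import SimpleNamespace

# from megatron.core.models.huggingface import build_hf_model

hf_wrapper_config = SimpleNamespace(huggingface_model_name_or_path="Qwen/Qwen2-1.5B")
# model = build_hf_model(hf_wrapper_config)  # -> QwenHuggingFaceModel, since model_type "qwen2" contains "qwen"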
+ +from transformers.models.qwen2 import Qwen2ForCausalLM + +from megatron.core.models.huggingface import HuggingFaceModule + + +class QwenHuggingFaceModel(HuggingFaceModule): + """ + Wrapper for Qwen LM HuggingFace models + """ + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2ForCausalLM.from_pretrained(config.huggingface_model_name_or_path) + + def forward(self, *args, **kwargs): + """Forward function""" + combined_embeddings = kwargs['decoder_input'].permute(1, 0, 2) + x = self.model( + position_ids=None, # TODO: I guess we're just assuming no custom pos ids + attention_mask=kwargs['attention_mask'], + inputs_embeds=combined_embeddings, + labels=kwargs['labels'], + ) + + if kwargs['labels'] is not None: + x = x["loss"] + else: + x = x["logits"] + + return x + + def embedding(self, input_ids, position_ids=None): + """Function to run process tokens with input embeddings""" + return self.model.get_input_embeddings()(input_ids).transpose(1, 0).contiguous() diff --git a/megatron/core/models/mamba/__init__.py b/megatron/core/models/mamba/__init__.py new file mode 100644 index 0000000000..5aaf852401 --- /dev/null +++ b/megatron/core/models/mamba/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from .mamba_model import MambaModel diff --git a/megatron/core/models/mamba/mamba_layer_specs.py b/megatron/core/models/mamba/mamba_layer_specs.py new file mode 100755 index 0000000000..e5fa9efa72 --- /dev/null +++ b/megatron/core/models/mamba/mamba_layer_specs.py @@ -0,0 +1,67 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +from megatron.core.extensions.transformer_engine import ( + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TERowParallelLinear, +) +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules +from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules +from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + +mamba_stack_spec = ModuleSpec( + module=MambaStack, + submodules=MambaStackSubmodules( + mamba_layer=ModuleSpec( + module=MambaLayer, + submodules=MambaLayerSubmodules( + mixer=ModuleSpec( + module=MambaMixer, + submodules=MambaMixerSubmodules( + in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear + ), + ), + mamba_bda=get_bias_dropout_add, + ), + ), + # Started with spec from gpt_layer_specs.py (with MLP removed) + # Using the TE spec because we had problems getting the non-TE spec + # working + attention_layer=ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + ), + ), + self_attn_bda=get_bias_dropout_add, + ), + ), + # Started with spec from gpt_layer_specs.py + # Using the TE spec because we had problems getting the non-TE spec + # working + mlp_layer=ModuleSpec( + 
module=TransformerLayer, + submodules=TransformerLayerSubmodules( + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ), + ), +) diff --git a/megatron/core/models/mamba/mamba_model.py b/megatron/core/models/mamba/mamba_model.py new file mode 100644 index 0000000000..5794b1b41a --- /dev/null +++ b/megatron/core/models/mamba/mamba_model.py @@ -0,0 +1,228 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +from typing import Literal, Optional + +from torch import Tensor + +from megatron.core import InferenceParams, tensor_parallel +from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk +from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from megatron.core.models.common.language_module.language_module import LanguageModule +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig + + +class MambaModel(LanguageModule): + """Mamba language model. + + Args: + config (TransformerConfig): Transformer config + mamba_stack_spec (ModuleSpec): Specifies the modules to use for the various layer types + vocab_size (int): Vocabulary size + max_sequence_length (int): maximum size of sequence. + This is used for positional embedding + pre_process (bool, optional): Include embedding layer + (used with pipeline parallelism). Defaults to True. + mamba_ssm_ngroups (int, optional): Specifies the number of groups to use. + The default value is 8, as in the NVIDIA Mamba2 (pure and hybrid) 8b. + However, in the original Mamba2 paper, the checkpoints use a setting of 1. + Defaults to 8. + hybrid_attention_ratio (float, optional): The target ratio of attention + layers to total layers + hybrid_mlp_ratio (float, optional): The target ratio of mlp layers to total layers + hybrid_override_pattern (str, optional): The hybrid layer pattern to override with + post_process (bool, optional): Include an output layer (used with pipeline parallelism). + Defaults to True. + fp16_lm_cross_entropy (bool, optional): Defaults to False. + parallel_output (bool, optional): Do not gather the outputs, keep them split across tensor + parallel ranks. Defaults to True. + share_embeddings_and_output_weights (bool, optional): When True, input embeddings and + output logit weights are shared. Defaults to False. + position_embedding_type (Literal[learned_absolute,rope,none], optional): Position + embedding type. Defaults to 'none'. + rotary_percent (float, optional): Percent of rotary dimension to use for rotary position + embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 1.0. + rotary_base (int, optional): Base period for rotary position embeddings. Ignored unless + position_embedding_type is 'rope'. Defaults to 10000. + seq_len_interpolation_factor (Optional[float], optional): scale of linearly + interpolating RoPE for longer sequences. The value must be a float larger than 1.0. + Defaults to None. 
+ """ + + def __init__( + self, + config: TransformerConfig, + mamba_stack_spec: ModuleSpec, + vocab_size: int, + max_sequence_length: int, + mamba_ssm_ngroups: int = 8, + pre_process: bool = True, + hybrid_attention_ratio: float = 0.0, + hybrid_mlp_ratio: float = 0.0, + hybrid_override_pattern: str = None, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + share_embeddings_and_output_weights: bool = False, + # Mamba with no attention has no need for position embeddings, so none is default + position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'none', + rotary_percent: float = 1.0, + rotary_base: int = 10000, + seq_len_interpolation_factor: Optional[float] = None, + ) -> None: + super().__init__(config=config) + + if has_config_logger_enabled(config): + log_config_to_disk(config, locals(), prefix=type(self).__name__) + + self.mamba_stack_spec: ModuleSpec = mamba_stack_spec + self.vocab_size = vocab_size + self.max_sequence_length = max_sequence_length + self.mamba_ssm_ngroups = mamba_ssm_ngroups + self.pre_process = pre_process + self.hybrid_attention_ratio = hybrid_attention_ratio + self.hybrid_mlp_ratio = hybrid_mlp_ratio + self.hybrid_override_pattern = hybrid_override_pattern + self.post_process = post_process + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + self.parallel_output = parallel_output + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.position_embedding_type = position_embedding_type + + # megatron core pipelining currently depends on model type + # TODO: remove this dependency ? + self.model_type = ModelType.encoder_or_decoder + + if self.pre_process: + self.embedding = LanguageModelEmbedding( + config=self.config, + vocab_size=self.vocab_size, + max_sequence_length=self.max_sequence_length, + position_embedding_type=position_embedding_type, + ) + + if self.position_embedding_type == 'rope': + self.rotary_pos_emb = RotaryEmbedding( + kv_channels=self.config.kv_channels, + rotary_percent=rotary_percent, + seq_len_interpolation_factor=seq_len_interpolation_factor, + rotary_base=rotary_base, + use_cpu_initialization=self.config.use_cpu_initialization, + ) + + self.decoder = build_module( + mamba_stack_spec, + self.config, + mamba_ssm_ngroups=self.mamba_ssm_ngroups, + pre_process=self.pre_process, + hybrid_attention_ratio=self.hybrid_attention_ratio, + hybrid_mlp_ratio=self.hybrid_mlp_ratio, + hybrid_override_pattern=self.hybrid_override_pattern, + post_process=self.post_process, + dtype=config.params_dtype, + ) + + # Output + if post_process: + self.output_layer = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + self.vocab_size, + config=config, + init_method=config.init_method, + bias=False, + skip_bias_add=False, + gather_output=not self.parallel_output, + skip_weight_param_allocation=self.pre_process + and self.share_embeddings_and_output_weights, + ) + + if self.pre_process or self.post_process: + self.setup_embeddings_and_output_layer() + + def set_input_tensor(self, input_tensor: Tensor) -> None: + """Sets input tensor to the model. + + See megatron.model.transformer.set_input_tensor() + + Args: + input_tensor (Tensor): Sets the input tensor for the model. 
+ """ + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + + assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert' + self.decoder.set_input_tensor(input_tensor[0]) + + def forward( + self, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + decoder_input: Tensor = None, + labels: Tensor = None, + inference_params: InferenceParams = None, + ) -> Tensor: + """Forward function of the Mamba model. This function passes the input tensors + through the embedding layer, and then the decoder and finally into the post + processing layer (optional). + + It either returns the Loss values if labels are given or the final hidden units + """ + # If decoder_input is provided (not None), then input_ids and position_ids are ignored. + # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. + + # Decoder embedding. + if decoder_input is not None: + pass + elif self.pre_process: + decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + else: + # intermediate stage of pipeline + # decoder will get hidden_states from encoder.input_tensor + decoder_input = None + + rotary_pos_emb = None + if self.position_embedding_type == 'rope': + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_params, self.decoder, decoder_input, self.config + ) + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + + # The following assert will currently fail when running inference. + # Commented out for now. + # TODO (duncan/rwaleffe): (1) confirm that the externally-generated + # attention mask is not needed and is ignored by the model in + # inference mode, (2) reduce the size of the externally-generated + # attention mask to prevent CPU OOM (as we did for training), (3) + # force the attention mask passed to the model in inference mode to + # be None, so this assert will succeed. + # assert attention_mask is None, "The attention mask is ignored and should be set to None" + + # Run decoder. + hidden_states = self.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + ) + + if not self.post_process: + return hidden_states + + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + logits, _ = self.output_layer(hidden_states, weight=output_weight) + + if labels is None: + # [s b h] => [b s h] + return logits.transpose(0, 1).contiguous() + + loss = self.compute_language_model_loss(labels, logits) + + return loss diff --git a/megatron/core/models/multimodal/__init__.py b/megatron/core/models/multimodal/__init__.py new file mode 100644 index 0000000000..f8011007a5 --- /dev/null +++ b/megatron/core/models/multimodal/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/core/models/multimodal/context_parallel.py b/megatron/core/models/multimodal/context_parallel.py new file mode 100644 index 0000000000..1cda5994a0 --- /dev/null +++ b/megatron/core/models/multimodal/context_parallel.py @@ -0,0 +1,99 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+"""Multimodal Sequence Parallel (SP) and Context Parallel (CP) functionality.""" + +import torch + +from megatron.core.packed_seq_params import PackedSeqParams + + +def get_padding( + seq_len, cp_size, tp_size, has_sp, decoder_tp_comm_overlap=False, decoder_seq_len=None +): + """Calculate padding needed for SP and/or CP. + + Args: + seq_len (int): Model sequence length. + cp_size (int): Context parallel size. + tp_size (int): Tensor parallel size. + has_sp (bool): Model uses sequence parallelism. + decoder_tp_comm_overlap (bool): Decoder (LLM) uses tensor parallel communication overlap. + decoder_seq_len (int): Decoder (LLM) maximum sequence length. + + Returns: + padding (int): Padding needed given model configuration. + """ + + padding = 0 + # TP Comm overlap is performed with combined text+image embeddings. + if has_sp and decoder_tp_comm_overlap: + # If TP Comm Overlap is enabled for combined text+image embedding in LM backbone, + # user needs to provide decoder_seq_len with any potential padding needed for SP+CP + assert ( + decoder_seq_len is not None + ), "Please provide decoder seq length when using TP comm overlap for LM backbone" + padding = decoder_seq_len - seq_len + elif has_sp or cp_size > 1: + padding_factor = 1 + if has_sp and cp_size > 1: + # Padding to multiple of tp_size * cp_size * 2 when using CP + SP. + padding_factor = tp_size * cp_size * 2 + elif cp_size > 1: + padding_factor = cp_size * 2 + elif has_sp: + padding_factor = tp_size + + padding = int((seq_len + padding_factor - 1) // padding_factor * padding_factor) - seq_len + + return padding + + +def get_packed_seq_params(tokens, img_seq_len, padding_needed, cp_size, use_packed_sequence=False): + """Get PackedSeqParams for CP. + + Args: + tokens (torch.Tensor): [batch, seq_len] input tokens. + img_seq_len (int): Image sequence length. + padding_needed (int): Padding to add. + cp_size (int): Context parallel size. + use_packed_sequence (bool): Uses sequence packing. + + Returns: + packed_seq_params (PackedSeqParams): Parameters to be sent to Transformer Engine. + """ + batch_size = tokens.shape[0] + # Calculate the valid token seq len that LM backbone should compute on + combined_valid_seqlen = tokens.shape[1] + img_seq_len - padding_needed + cu_seqlens = torch.arange( + 0, + (batch_size + 1) * (combined_valid_seqlen), + step=(combined_valid_seqlen), + dtype=torch.int32, + device=tokens.device, + ) + # Calculate the total padded token seq len + combined_padded_seqlen = tokens.shape[1] + img_seq_len + cu_seqlens_padded = None + qkv_format = 'sbhd' + if cp_size > 1 and (padding_needed > 0 or use_packed_sequence): + # Provide cu_seqlens__padded for CP support + cu_seqlens_padded = torch.arange( + 0, + (batch_size + 1) * (combined_padded_seqlen), + step=(combined_padded_seqlen), + dtype=torch.int32, + device=tokens.device, + ) + # CP with padding mask type requires THD format + qkv_format = 'thd' + + packed_seq_params = PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + max_seqlen_q=combined_padded_seqlen, + max_seqlen_kv=combined_padded_seqlen, + qkv_format=qkv_format, + ) + + return packed_seq_params diff --git a/megatron/core/models/multimodal/llava_model.py b/megatron/core/models/multimodal/llava_model.py new file mode 100644 index 0000000000..f9ae2314b0 --- /dev/null +++ b/megatron/core/models/multimodal/llava_model.py @@ -0,0 +1,958 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
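A small worked example of the two helpers in context_parallel.py above, to make the padding arithmetic concrete (not part of this diff; the sizes, 1624 text tokens plus 576 image embeddings with TP=4 and CP=2, are illustrative assumptions):

import torch

from megatron.core.models.multimodal.context_parallel import get_packed_seq_params, get_padding

img_seq_len = 576       # e.g. number of image embeddings produced by the vision encoder
text_seq_len = 1624
combined_seq_len = text_seq_len + img_seq_len  # 2200

# With SP enabled, TP=4 and CP=2, the combined sequence must be a multiple of
# tp_size * cp_size * 2 = 16, so 8 padding tokens are needed (2200 -> 2208).
padding = get_padding(combined_seq_len, cp_size=2, tp_size=4, has_sp=True)
assert padding == 8

# get_packed_seq_params derives the padded length from tokens.shape[1] + img_seq_len,
# so the padding is assumed to have been appended to the text tokens already.
tokens = torch.zeros(2, text_seq_len + padding, dtype=torch.long)
params = get_packed_seq_params(tokens, img_seq_len, padding, cp_size=2)
assert params.qkv_format == 'thd'                                # CP with padding switches TE to the THD layout
assert params.max_seqlen_q == 2208                               # padded combined length
assert params.cu_seqlens_q.tolist() == [0, 2200, 4400]           # valid lengths per sample
assert params.cu_seqlens_q_padded.tolist() == [0, 2208, 4416]    # padded lengths per sample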
+import logging
+from collections import namedtuple
+from functools import partial
+from typing import List, Optional
+
+import torch
+
+from megatron.core import InferenceParams, tensor_parallel
+from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
+from megatron.core.models.gpt import GPTModel
+from megatron.core.models.vision.clip_vit_model import CLIPViTModel, get_num_image_embeddings
+from megatron.core.models.vision.multimodal_projector import MultimodalProjector
+from megatron.core.models.vision.radio import RADIOViTModel
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.parallel_state import get_context_parallel_rank, get_context_parallel_world_size
+from megatron.core.transformer import MegatronModule
+from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.core.utils import log_single_rank
+
+try:
+    import transformer_engine  # pylint: disable=unused-import
+
+    from megatron.core.extensions.transformer_engine import TEDotProductAttention
+    from megatron.core.utils import is_te_min_version
+
+    HAVE_TE = True
+    try:
+        import transformer_engine_torch as tex
+
+        HAVE_TEX = True
+    except ImportError:
+        HAVE_TEX = False
+except ImportError:
+    HAVE_TE = False
+    if get_context_parallel_world_size() > 1:
+        raise RuntimeError("Context parallelism requires Transformer Engine support, but it was not found.")
+
+
+IGNORE_INDEX = -100  # ID for labels that should be ignored.
+# Image token index can be tokenizer dependent so the default value does not work in all cases.
+DEFAULT_IMAGE_TOKEN_INDEX = -200
+IMAGE_TOKEN = "<image>"
+VIDEO_TOKEN = "