From 8c80d9100f814ad1344774443c155891be2dfe64 Mon Sep 17 00:00:00 2001 From: kazuki Date: Fri, 3 Jan 2025 15:09:49 +0900 Subject: [PATCH 1/2] fix: fix type annotation --- megatron/core/extensions/transformer_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 5884109cae..9f5e1f7720 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -1108,7 +1108,7 @@ class TEDelayedScaling(te.common.recipe.DelayedScaling): def __init__( self, - config: ModelParallelConfig, + config: TransformerConfig, fp8_format: int, override_linear_precision: tuple = (False, False, False), ): From c0dffe2cf159e9db8e2b0bd1eeb2c2eaea462b79 Mon Sep 17 00:00:00 2001 From: kazuki Date: Fri, 3 Jan 2025 15:15:01 +0900 Subject: [PATCH 2/2] fix: conversion script from hf to mcore for TransformerEngine v1.10+ --- tools/checkpoint/convert.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tools/checkpoint/convert.py b/tools/checkpoint/convert.py index 935613b143..83d694eb37 100644 --- a/tools/checkpoint/convert.py +++ b/tools/checkpoint/convert.py @@ -2,6 +2,7 @@ import argparse import importlib +import torch import torch.multiprocessing as mp import sys @@ -107,6 +108,9 @@ def load_plugin(plugin_type, name): return plugin def main(): + if not torch.cuda.is_initialized(): + torch.cuda.init() + import argparse parser = argparse.ArgumentParser(description="Megatron Checkpoint Converter Arguments", allow_abbrev=False, conflict_handler='resolve') @@ -151,4 +155,5 @@ def main(): if __name__ == '__main__': + mp.set_start_method(method='spawn') main()