diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py
index 5884109cae..9f5e1f7720 100644
--- a/megatron/core/extensions/transformer_engine.py
+++ b/megatron/core/extensions/transformer_engine.py
@@ -1108,7 +1108,7 @@ class TEDelayedScaling(te.common.recipe.DelayedScaling):
 
     def __init__(
         self,
-        config: ModelParallelConfig,
+        config: TransformerConfig,
         fp8_format: int,
         override_linear_precision: tuple = (False, False, False),
     ):
diff --git a/tools/checkpoint/convert.py b/tools/checkpoint/convert.py
index 935613b143..83d694eb37 100644
--- a/tools/checkpoint/convert.py
+++ b/tools/checkpoint/convert.py
@@ -2,6 +2,7 @@
 
 import argparse
 import importlib
+import torch
 import torch.multiprocessing as mp
 import sys
 
@@ -107,6 +108,9 @@ def load_plugin(plugin_type, name):
     return plugin
 
 def main():
+    if not torch.cuda.is_initialized():
+        torch.cuda.init()
+
     import argparse
     parser = argparse.ArgumentParser(description="Megatron Checkpoint Converter Arguments",
                                      allow_abbrev=False, conflict_handler='resolve')
@@ -151,4 +155,5 @@ def main():
 
 
 if __name__ == '__main__':
+    mp.set_start_method(method='spawn')
    main()
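Outside the patch itself, here is a minimal, self-contained sketch of the multiprocessing pattern the `convert.py` changes rely on: the start method is set to `spawn` before any worker processes are created, and CUDA is initialized explicitly inside each process. The `worker` function and the process count are hypothetical and only for illustration.

```python
import torch
import torch.multiprocessing as mp


def worker(rank):
    # Hypothetical worker: each spawned process gets a fresh interpreter,
    # so it can initialize CUDA itself instead of inheriting a forked,
    # already-initialized CUDA context (which would fail).
    if torch.cuda.is_available() and not torch.cuda.is_initialized():
        torch.cuda.init()
    print(f"worker {rank} started")


if __name__ == '__main__':
    # Mirrors the diff: choose the start method before creating processes.
    mp.set_start_method(method='spawn')
    procs = [mp.Process(target=worker, args=(i,)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```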