diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index f0c0b22a..eccfefe0 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -164,17 +164,17 @@ def from_config(cfg: Extract, use_fdsp: bool, cpu_only: bool) -> "LoadedModel": fsdp_port = find_available_port() msg = f"Fully Sharded Data Parallel running on port {fsdp_port}" - layer_cls = get_transformer_layer_cls(model) - if layer_cls is not None: - msg += f" with '{layer_cls.__name__}' wrapping policy" - wrap_policy = ( - partial(transformer_auto_wrap_policy, transformer_layer_cls={layer_cls}) - if layer_cls is not None - else None - ) + # layer_cls = get_transformer_layer_cls(model) + # if layer_cls is not None: + # msg += f" with '{layer_cls.__name__}' wrapping policy" + # wrap_policy = ( + # partial(transformer_auto_wrap_policy, transformer_layer_cls={layer_cls}) + # if layer_cls is not None + # else None + # ) fsdp_model = FullyShardedDataParallel( module=model, - auto_wrap_policy=wrap_policy, + auto_wrap_policy=None, cpu_offload=CPUOffload(offload_params=False), ) print(msg)