From 5e661548ccecbd4ab77fd17eb4ac8b19bcc069a9 Mon Sep 17 00:00:00 2001
From: James Chua
Date: Wed, 26 Apr 2023 21:11:58 +0800
Subject: [PATCH] try to use FSDP

---
 elk/extraction/extraction.py | 74 ++++++++++++++++++++++++++++++++----
 1 file changed, 66 insertions(+), 8 deletions(-)

diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py
index 10620bec..f0c0b22a 100644
--- a/elk/extraction/extraction.py
+++ b/elk/extraction/extraction.py
@@ -3,10 +3,21 @@
 import os
 from copy import copy
 from dataclasses import InitVar, dataclass
+from functools import partial
 from itertools import islice
-from typing import Any, Iterable, Literal
+from typing import Any, Iterable, Literal, Optional
 from warnings import filterwarnings
 
+from torch.distributed.elastic.utils import get_socket_with_port
+from torch.distributed.fsdp import (
+    CPUOffload,
+    FullyShardedDataParallel as FSDP,
+    FullyShardedDataParallel,
+)
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+from transformers import PreTrainedModel
+from transformers.modeling_outputs import ModelOutput
+from typing import Any, Callable, Iterable, Type, cast
 import torch
 from datasets import (
     Array2D,
@@ -95,6 +106,31 @@ def explode(self) -> list["Extract"]:
         return copies
 
 
+@dataclass(kw_only=True)
+class FSDPOptions:
+    port: int
+
+
+def find_available_port() -> int:
+    s = get_socket_with_port()
+    _, port = s.getsockname()
+    s.close()
+
+    return port
+
+
+def get_transformer_layer_cls(model: torch.nn.Module) -> Type[torch.nn.Module] | None:
+    """Get the class of the transformer layer used by the given model."""
+    total_params = sum(p.numel() for p in model.parameters())
+    for module in model.modules():
+        if isinstance(module, torch.nn.ModuleList):
+            module_params = sum(p.numel() for p in module.parameters())
+            if module_params > total_params / 2:
+                return type(module[0])
+
+    return None
+
+
 @dataclass(kw_only=True)
 class LoadedModel:
     """A model and its tokenizer."""
@@ -103,6 +139,7 @@ class LoadedModel:
     tokenizer: PreTrainedTokenizerBase
     is_encoder_decoder: bool
     has_lm_preds: bool
+    fsdp_options: Optional[FSDPOptions]
 
     def share_memory(self):
         """Makes the model share memory across processes.
@@ -112,17 +149,36 @@ def share_memory(self):
         """
         self.model.share_memory()
 
-    def to_device(self, device: str):
+    def to_device(self, device: str | torch.device):
         """Moves the model to the specified device."""
         self.model.to(device)
 
     @staticmethod
-    def from_config(cfg: Extract, device: str) -> "LoadedModel":
+    def from_config(cfg: Extract, use_fsdp: bool, cpu_only: bool) -> "LoadedModel":
         model = instantiate_model(
-            cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32
-        ).to(device)
-        print(f"Model {cfg.model} loaded on device:", device)
+            cfg.model, torch_dtype="auto" if not cpu_only else torch.float32
+        )
+        print(f"Model {cfg.model} loaded on CPU")
         tokenizer = instantiate_tokenizer(cfg.model, truncation_side="left")
 
+        if use_fsdp:
+            fsdp_port = find_available_port()
+            msg = f"Fully Sharded Data Parallel running on port {fsdp_port}"
+
+            layer_cls = get_transformer_layer_cls(model)
+            if layer_cls is not None:
+                msg += f" with '{layer_cls.__name__}' wrapping policy"
+            wrap_policy = (
+                partial(transformer_auto_wrap_policy, transformer_layer_cls={layer_cls})
+                if layer_cls is not None
+                else None
+            )
+            fsdp_model = FullyShardedDataParallel(
+                module=model,
+                auto_wrap_policy=wrap_policy,
+                cpu_offload=CPUOffload(offload_params=False),
+            )
+            print(msg)
+            model = fsdp_model
         is_enc_dec = model.config.is_encoder_decoder
         if is_enc_dec and cfg.use_encoder_states:
@@ -138,6 +194,7 @@ def from_config(cfg: Extract, device: str) -> "LoadedModel":
             tokenizer=tokenizer,
             is_encoder_decoder=is_enc_dec,
             has_lm_preds=has_lm_preds,
+            fsdp_options=FSDPOptions(port=fsdp_port) if use_fsdp else None,
         )
 
 
@@ -159,7 +216,8 @@ def extract_hiddens(
     filterwarnings("ignore")
     logging.disable(logging.CRITICAL)
 
-    loaded_model.to_device(device)
+    if loaded_model.fsdp_options is not None:
+        loaded_model.to_device(device)
 
     p_cfg = cfg.prompts
     ds_names = p_cfg.datasets
@@ -393,7 +451,7 @@ def get_splits() -> SplitDict:
     devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem)
 
     # Decide where to load the model from - CPU vs one of the GPUs
-    loaded_model = LoadedModel.from_config(cfg, use_fsdp=True, cpu_only=False)
+    loaded_model = LoadedModel.from_config(cfg, use_fsdp=True, cpu_only=False)
     print("Loaded model from config successful")
    # Share the model across all processes if we're using multiple GPUs
     if len(devices) > 1:
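
Note: the patch finds a free port and records it in FSDPOptions, but nothing in the diff
initializes a torch.distributed process group, which FullyShardedDataParallel needs before
it can shard a module across GPUs. Below is a minimal sketch of how a per-GPU worker might
presumably consume that stored port; the worker, rank, and world_size names are illustrative
assumptions and not part of this patch.

    # Hedged sketch, not part of the patch: presumed per-worker setup that would
    # consume LoadedModel.fsdp_options.port before FSDP can shard the model.
    # Assumes one worker process per CUDA device.
    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

    def worker(rank: int, world_size: int, model: torch.nn.Module, port: int) -> None:
        # FSDP requires an initialized default process group.
        dist.init_process_group(
            backend="nccl",
            init_method=f"tcp://localhost:{port}",
            rank=rank,
            world_size=world_size,
        )
        torch.cuda.set_device(rank)
        fsdp_model = FSDP(
            model,
            cpu_offload=CPUOffload(offload_params=False),
            device_id=rank,  # place this worker's shard on its own GPU
        )
        # ... run extraction with fsdp_model, then tear down ...
        dist.destroy_process_group()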