try to use fdsp
thejaminator committed Apr 26, 2023
1 parent 8a3db35 commit 5e66154
74 changes: 66 additions & 8 deletions elk/extraction/extraction.py
@@ -3,10 +3,21 @@
import os
from copy import copy
from dataclasses import InitVar, dataclass
from functools import partial
from itertools import islice
-from typing import Any, Iterable, Literal
+from typing import Any, Iterable, Literal, Optional
from warnings import filterwarnings

from torch.distributed.elastic.utils import get_socket_with_port
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from typing import Callable, Type, cast
import torch
from datasets import (
Array2D,
@@ -95,6 +106,31 @@ def explode(self) -> list["Extract"]:
return copies


@dataclass(kw_only=True)
class FSDPOptions:
    port: int


def find_available_port() -> int:
    """Reserve a free TCP port by binding port 0 and reading back the OS's choice."""
    s = get_socket_with_port()
    _, port = s.getsockname()
    s.close()

    return port
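
The same trick works with the standard library alone. A minimal sketch (find_available_port_stdlib is a hypothetical helper; the assumption is that binding port 0 makes the OS hand back a currently free port):

import socket

def find_available_port_stdlib() -> int:
    # Bind port 0 so the OS picks a free port, then read the choice back.
    # Note the same race as above: the port could be reclaimed before use.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]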


def get_transformer_layer_cls(model: torch.nn.Module) -> Type[torch.nn.Module] | None:
    """Get the class of the transformer layer used by the given model."""
    total_params = sum(p.numel() for p in model.parameters())
    for module in model.modules():
        if isinstance(module, torch.nn.ModuleList):
            module_params = sum(p.numel() for p in module.parameters())
            if module_params > total_params / 2:
                # The ModuleList holding a majority of the parameters is almost
                # certainly the stack of transformer blocks.
                return type(module[0])

    return None
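
As a quick sanity check of this majority-of-parameters heuristic (assuming the transformers package and the small "gpt2" checkpoint are available): GPT-2's transformer.h is the ModuleList holding most of its weights, so the detected class should be GPT2Block.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
print(get_transformer_layer_cls(model))  # expected: GPT2Block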


@dataclass(kw_only=True)
class LoadedModel:
"""A model and its tokenizer."""
@@ -103,6 +139,7 @@ class LoadedModel:
    tokenizer: PreTrainedTokenizerBase
    is_encoder_decoder: bool
    has_lm_preds: bool
    fsdp_options: Optional[FSDPOptions]

    def share_memory(self):
        """Makes the model share memory across processes.
@@ -112,17 +149,36 @@ def share_memory(self):
"""
self.model.share_memory()

def to_device(self, device: str):
def to_device(self, device: str | torch.device):
"""Moves the model to the specified device."""
self.model.to(device)
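
Either spelling of a device is accepted after this signature change; for a hypothetical LoadedModel instance named loaded:

loaded.to_device("cuda:0")                 # plain string
loaded.to_device(torch.device("cuda", 0))  # torch.device object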

    @staticmethod
-    def from_config(cfg: Extract, device: str) -> "LoadedModel":
+    def from_config(cfg: Extract, use_fdsp: bool, cpu_only: bool) -> "LoadedModel":
        model = instantiate_model(
-            cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32
-        ).to(device)
-        print(f"Model {cfg.model} loaded on device:", device)
+            cfg.model, torch_dtype="auto" if not cpu_only else torch.float32
+        )
+        print(f"Model {cfg.model} loaded on cpu")
        tokenizer = instantiate_tokenizer(cfg.model, truncation_side="left")
        if use_fdsp:
            fsdp_port = find_available_port()
            msg = f"Fully Sharded Data Parallel running on port {fsdp_port}"

            layer_cls = get_transformer_layer_cls(model)
            if layer_cls is not None:
                msg += f" with '{layer_cls.__name__}' wrapping policy"
            # Shard at transformer-block granularity when the block class is
            # known; otherwise fall back to FSDP's default (no auto wrapping).
            wrap_policy = (
                partial(transformer_auto_wrap_policy, transformer_layer_cls={layer_cls})
                if layer_cls is not None
                else None
            )
            fsdp_model = FullyShardedDataParallel(
                module=model,
                auto_wrap_policy=wrap_policy,
                cpu_offload=CPUOffload(offload_params=False),
            )
            print(msg)
            model = fsdp_model

        is_enc_dec = model.config.is_encoder_decoder
        if is_enc_dec and cfg.use_encoder_states:
@@ -138,6 +194,7 @@ def from_config(cfg: Extract, device: str) -> "LoadedModel":
            tokenizer=tokenizer,
            is_encoder_decoder=is_enc_dec,
            has_lm_preds=has_lm_preds,
            # Record the port actually reserved above, not a placeholder.
            fsdp_options=FSDPOptions(port=fsdp_port) if use_fdsp else None,
        )
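
Note that FullyShardedDataParallel can only be constructed once a default process group exists, which this commit does not yet set up. A minimal sketch of the missing step, assuming a single process and that fsdp_port is the port reserved above (MASTER_ADDR and MASTER_PORT are the standard torch.distributed rendezvous variables):

import os
import torch.distributed as dist

os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(fsdp_port)
# A single rank is enough to satisfy FSDP's initialized-group requirement.
dist.init_process_group(
    backend="nccl" if torch.cuda.is_available() else "gloo",
    rank=0,
    world_size=1,
)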


@@ -159,7 +216,8 @@ def extract_hiddens(
    filterwarnings("ignore")
    logging.disable(logging.CRITICAL)

-    loaded_model.to_device(device)
+    if loaded_model.fsdp_options is not None:
+        loaded_model.to_device(device)

    p_cfg = cfg.prompts
    ds_names = p_cfg.datasets
@@ -393,7 +451,7 @@ def get_splits() -> SplitDict:

devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem)
# Decide where to load the model from - CPU vs one of the GPUs
loaded_model = LoadedModel.from_config(cfg, device=devices[0])
loaded_model = LoadedModel.from_config(cfg, use_fdsp=True, cpu_only=False)
print("Loaded model from config successful")
# Share the model across all processes if we're using multiple GPUs
if len(devices) > 1:
