Skip to content

Commit

Permalink
Allow BF16 dtype support on CPU (pytorch#1218)
Browse files — browse the repository at this point in the history
Authored by sanchitintel on Jul 26, 2024
1 parent 651b18c commit e101420
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/source/deep_dives/recipe_deepdive.rst
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ Initialize recipe state including seed, device, dtype, metric loggers, relevant
def __init__(...):
self._device = utils.get_device(device=params.device)
self._dtype = utils.get_dtype(dtype=params.dtype)
self._dtype = utils.get_dtype(dtype=params.dtype, device=self._device)
...
Load checkpoint, update recipe state from checkpoint, initialize components and load state dicts from checkpoint
Expand Down
2 changes: 1 addition & 1 deletion recipes/eleuther_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def __init__(self, cfg: DictConfig) -> None:

def setup(self) -> None:
self._device = utils.get_device(device=self._cfg.device)
self._dtype = utils.get_dtype(dtype=self._cfg.dtype)
self._dtype = utils.get_dtype(dtype=self._cfg.dtype, device=self._device)
self._limit = self._cfg.limit
self._tasks = list(self._cfg.tasks)
self._quantizer = config.instantiate(self._cfg.quantizer)
Expand Down
2 changes: 1 addition & 1 deletion recipes/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class InferenceRecipe:

def __init__(self, cfg: DictConfig) -> None:
self._device = utils.get_device(device=cfg.device)
self._dtype = utils.get_dtype(dtype=cfg.dtype)
self._dtype = utils.get_dtype(dtype=cfg.dtype, device=self._device)
self._quantizer = config.instantiate(cfg.quantizer)
self._quantization_mode = utils.get_quantizer_mode(self._quantizer)

Expand Down
2 changes: 1 addition & 1 deletion recipes/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class QuantizationRecipe:

def __init__(self, cfg: DictConfig) -> None:
self._device = utils.get_device(device=cfg.device)
self._dtype = utils.get_dtype(dtype=cfg.dtype)
self._dtype = utils.get_dtype(dtype=cfg.dtype, device=self._device)
self._quantizer = config.instantiate(cfg.quantizer)
self._quantization_mode = utils.get_quantizer_mode(self._quantizer)
utils.set_seed(seed=cfg.seed)
Expand Down
2 changes: 0 additions & 2 deletions torchtune/utils/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,6 @@ def get_dtype(
f"Dtype {torch_dtype} must be one of {', '.join(list(PRECISION_STR_TO_DTYPE.keys()))} for finetuning."
)

# TODO (rohan-varma): prefer to use get_default_device() here to figure out whether user is training on
# CPU or GPU, but it is not supported in versions of torch we test.
if (
torch_dtype == torch.bfloat16
and device != torch.device("cpu")
Expand Down

0 comments on commit e101420

Please sign in to comment.