This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 238e8f9

update asserts with more info
1 parent 4ca6ddf commit 238e8f9

File tree

1 file changed: +8 -2 lines changed


float8_experimental/float8_linear_utils.py

Lines changed: 8 additions & 2 deletions
@@ -210,10 +210,16 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
         scale_fn_recipes.add(child.recipe.scale_fn_name)

     # TODO This way to get the activation dtype is not ideal
-    assert len(x_dtypes) == 1, "All layers must have the same last seen input_dtype"
+    if len(x_dtypes) != 1:
+        raise ValueError(
+            f"All layers must have the same last seen input_dtype, got {x_dtypes}"
+        )
     x_dtype = next(iter(x_dtypes))

-    assert len(scale_fn_recipes) == 1, "All layers must have the same scale_fn recipe"
+    if len(scale_fn_recipes) != 1:
+        raise ValueError(
+            f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
+        )
     scale_fn_recipe = next(iter(scale_fn_recipes))

     assert (
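
For context, a minimal standalone sketch (not part of the commit) of the validation pattern the diff adopts: unlike `assert`, which is skipped when Python runs with the `-O` flag, an explicit `if`/`raise ValueError` check always executes, and the f-string message now reports the offending set. The `x_dtypes` value below is a hypothetical example, not data from the library.

import torch

# Hypothetical set of last-seen input dtypes collected across layers.
x_dtypes = {torch.float16, torch.bfloat16}

# The check style used in the commit: always executed (asserts are stripped
# under `python -O`) and the error message includes the offending values.
if len(x_dtypes) != 1:
    raise ValueError(
        f"All layers must have the same last seen input_dtype, got {x_dtypes}"
    )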

0 commit comments
