Commit

[Misc] Add assertion and helpful message for marlin24 compressed mode…
dsikka authored Dec 23, 2024
1 parent 2e72668 commit b866cdb
Showing 1 changed file with 4 additions and 0 deletions.
@@ -61,6 +61,10 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

assert params_dtype == torch.float16, (
"float16 is required for marlin24 compressed models. Set dtype=torch.float16" # noqa: E501
)

pack_factor = 32 // self.quant_type.size_bits
output_size_per_partition = sum(output_partition_sizes)
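
The two context lines after the new assertion are plain integer arithmetic, which can be sketched on its own. The snippet below is a minimal illustration, not code from this commit: the helper names `pack_factor` and `partition_size` and the divisibility guard are assumptions added for the example, and `size_bits` stands in for `self.quant_type.size_bits`.

```python
from typing import Sequence


def pack_factor(size_bits: int, word_bits: int = 32) -> int:
    """Return how many quantized values fit in one packed word.

    Mirrors `32 // self.quant_type.size_bits` from the diff. The
    divisibility guard is a hypothetical addition for this sketch.
    """
    assert size_bits > 0 and word_bits % size_bits == 0, (
        f"size_bits={size_bits} does not evenly divide a {word_bits}-bit word"
    )
    return word_bits // size_bits


def partition_size(output_partition_sizes: Sequence[int]) -> int:
    """Return the total output size held by this partition.

    Mirrors `sum(output_partition_sizes)` from the diff.
    """
    return sum(output_partition_sizes)


if __name__ == "__main__":
    # 4-bit values pack eight to a 32-bit word; 8-bit values pack four.
    print(pack_factor(4), pack_factor(8))
    # Two fused output partitions of 128 and 256 columns.
    print(partition_size([128, 256]))
```

Like the assertion added in the commit, the guard here fails early with a message that names the offending value, instead of letting a bad configuration surface later as a shape mismatch.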
