Skip to content

Commit

Permalink
Add/fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Feb 6, 2024
1 parent a653be0 commit e01f0d1
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,7 @@ def _is_supported_module(graph_model: GraphModule, node: Node) -> bool:
# We support only self-attention
if isinstance(module, nn.MultiheadAttention):
kwargs = dict(node.kwargs)
# When using hf/accelerate, we need to check the signature of the original forward
forward_to_check = module._old_forward if hasattr(
module, '_old_forward') else module.forward
kwargs.update(zip(forward_to_check.__code__.co_varnames[1:], node.args))
Expand Down Expand Up @@ -962,6 +963,7 @@ def create_mul_node(self, scale, shape, axis, batch_dim=0):

def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **kwargs):
# Check for MHA Cross attention, and if found, skip it
# When using hf/accelerate, we need to check the signature of the original forward
forward_to_check = module._old_forward if hasattr(
module, '_old_forward') else module.forward
kwargs.update(zip(forward_to_check.__code__.co_varnames[1:], args[:-1]))
Expand Down
1 change: 1 addition & 0 deletions src/brevitas_examples/common/generative/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class Int8DynamicActPerGroupFloat(DynamicActProxyMixin, Int8ActPerRowFloat):

@value
def stats_reduce_dim(group_dim):
# If group_dim == -1, we need a workaround to avoid selecting the wrong dim
if group_dim == -1:
return -1
else:
Expand Down
1 change: 1 addition & 0 deletions src/brevitas_examples/llm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- transformers
- datasets
- torch_mlir (optional for torch-mlir based export)
- optimum-amd (WIP; install from the brevitas-compatibility branch)

## Run

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def sharded_weight_group_export(model, no_custom_packed_export):
export_context_manager = brevitas_layer_export_mode
# generate an export_class with the handler declared above
export_class = block_quant_layer_level_manager(
export_handlers={LinearWeightBlockQuantHandlerFwd})
export_handlers=[LinearWeightBlockQuantHandlerFwd])

layers0 = [FirstVicunaLayer(layer) for layer in model.model.layers]
mlirs0 = compile_to_vmfb(
Expand Down

0 comments on commit e01f0d1

Please sign in to comment.