diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index bbbbc27bb..a2e08ec77 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -17,7 +17,7 @@ __all__ = [ 'Transform', - 'PerInputTrasform', + 'PerInputTransform', 'GraphTransform', 'PerInputModuleToModuleByHook', 'ModuleToModule', @@ -40,7 +40,7 @@ def apply(self, model: Module) -> Module: pass -class PerInputTrasform(ABC): +class PerInputTransform(ABC): @abstractmethod def apply(self, model: Module, inp: torch.Tensor) -> Module: @@ -66,7 +66,7 @@ def apply(self, graph_model: GraphModule) -> GraphModule: return graph_model -class PerInputModuleToModuleByHook(PerInputTrasform, ABC): +class PerInputModuleToModuleByHook(PerInputTransform, ABC): def __init__(self): self.input_size_map = {}