Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Lucas Wilkinson <[email protected]>
  • Loading branch information
LucasWilkinson committed Nov 15, 2024
1 parent 1993f3b commit 563f80c
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions benchmarks/kernels/benchmark_machete.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,10 @@ def to_torch_dtype(dt):
"int": torch.int,
"float": torch.float,
}[dt]

class ToTorchDtype(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, to_torch_dtype(values))

parser = FlexibleArgumentParser(
description="""
Expand All @@ -570,36 +574,34 @@ def to_torch_dtype(dt):
)
parser.add_argument(
"--act-type",
type=to_torch_dtype,
action=ToTorchDtype,
required=True,
help="Available options are "
"['bfloat16', 'float16', 'int8', 'float8_e4m3fn']",
choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'],
)
parser.add_argument(
"--group-scale-type",
type=to_torch_dtype,
help="Available options are ['bfloat16', 'float16']",
action=ToTorchDtype,
choices=['bfloat16', 'float16'],
)
parser.add_argument(
"--group-zero-type",
type=to_torch_dtype,
help="Available options are ['bfloat16', 'float16']",
choices=['bfloat16', 'float16'],
)
parser.add_argument(
"--channel-scale-type",
type=to_torch_dtype,
help="Available options are ['bfloat16', 'float16', 'float']",
action=ToTorchDtype,
choices=['float'],
)
parser.add_argument(
"--token-scale-type",
type=to_torch_dtype,
help="Available options are ['bfloat16', 'float16', 'float']",
action=ToTorchDtype,
choices=['float'],
)
parser.add_argument(
"--out-type",
type=to_torch_dtype,
help="Available options are "
"['bfloat16', 'float16', 'int', 'float']",
action=ToTorchDtype,
choices=['bfloat16', 'float16'],
)
parser.add_argument(
"--group-size",
Expand Down

0 comments on commit 563f80c

Please sign in to comment.