Skip to content

Commit

Permalink
Add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
RattataKing committed Aug 26, 2024
1 parent eaf12fe commit 9635551
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions sharktank/sharktank/tools/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,8 @@ def supports(self, op_name: str) -> bool:
return "matmul_transpose_b" in op_name

def get_shapes(self, template: list[str]) -> ProblemSize:
mmt_re = None
dps = None
for line in template:
if "linalg.generic" not in line:
continue
Expand Down Expand Up @@ -585,8 +587,8 @@ def get_shapes(self, template: list[str]) -> ProblemSize:
res_type=res_shaped_type,
dispatch_kind=DispatchKind.mmt,
)

assert False, "Shape not found"
assert mmt_re
assert dps, f"'{mmt_re}' not found in given context"

def get_transform_function_mmt(
self, problem_size: ProblemSize, functionName: str, configuration: Configuration
Expand Down Expand Up @@ -1271,13 +1273,13 @@ def walk_mlir_op(


def tune(
input: str,
output: str = "",
limit: int = 4096,
num_subgroups: int = 4,
lhs_dims: str = "mk",
rhs_dims: str = "nk",
tile_dims: str = "mnk",
input: str, # Path to the mlir file to be tuned
output: str = "", # Path to the output directory, auto creates one if not given
limit: int = 4096, # Max candidates to be generated
num_subgroups: int = 4, # GPU spec, used to determine candidate generation constraints
lhs_dims: str = "mk", # Dimensions for the left-hand side operand in matrix operations
rhs_dims: str = "nk", # Dimensions for the right-hand side operand in matrix operations
tile_dims: str = "mnk", # Dimensions for the tile size
):
input_file = str(input)

Expand Down

0 comments on commit 9635551

Please sign in to comment.