@@ -17,26 +17,32 @@ def model_benchmark_shapes(args):
         N = config["intermediate_size"]
         K = config["hidden_size"]
 
-        shapes.append((M, N, K))
+        shapes.append((M, N, K, 'TN'))
 
     return shapes
 
 
 def get_x_vals():
     x_vals = [
-        (1, 1280, 8192),
-        (32, 1280, 8192),
-        (64, 1280, 8192),
-        (128, 1280, 8192),
-        (192, 1280, 8192),
-        (256, 1280, 8192),
-        (320, 1280, 8192),
-        (512, 1280, 8192),
-        (1024, 1280, 8192),
-        (2048, 1280, 8192),
-        (4096, 1280, 8192),
-        (8192, 1280, 8192),
-        (16384, 1280, 8192),
+        (1, 1280, 8192, 'TN'),
+        (32, 1280, 8192, 'TN'),
+        (64, 1280, 8192, 'TN'),
+        (128, 1280, 8192, 'TN'),
+        (192, 1280, 8192, 'TN'),
+        (256, 1280, 8192, 'TN'),
+        (320, 1280, 8192, 'TN'),
+        (512, 1280, 8192, 'TN'),
+        (1024, 1280, 8192, 'TN'),
+        (2048, 1280, 8192, 'TN'),
+        (4096, 1280, 8192, 'TN'),
+        (8192, 1280, 8192, 'TN'),
+        (16384, 1280, 8192, 'TN'),
+        (8192, 7168, 20480, 'NT'),
+        (1024, 20480, 8192, 'NT'),
+        (1024, 8192, 20480, 'NT'),
+        (8192, 7168, 20480, 'TN'),
+        (1024, 20480, 8192, 'TN'),
+        (1024, 8192, 20480, 'TN'),
     ]
     return x_vals
 
@@ -45,11 +51,11 @@ def run_benchmark(args):
     assert not (args.shape and args.model) or not (args.shape and args.M), \
         "User can specify --shape or --model MODEL -M VAL exclusively"
 
-    x_names = ['M', 'N', 'K']
+    x_names = ['M', 'N', 'K', 'layout']
     if args.model:
         x_vals_list = model_benchmark_shapes(args)
     elif args.shape:
-        x_vals_list = [args.shape]
+        x_vals_list = [args.shape + [args.layout]]
     else:
         x_vals_list = get_x_vals()
 
@@ -71,10 +77,10 @@ def run_benchmark(args):
                    ylabel=ylabel, plot_name=f'GEMM A16W16 Benchmark', args={"metric": args.metric})
 
     @triton.testing.perf_report([benchmark])
-    def bench_gemm_a16w16(M, N, K, metric, provider):
+    def bench_gemm_a16w16(M, N, K, layout, metric, provider):
         # NOTE: Assume bias and output has the same dtype
         c_dtype = torch.bfloat16
-        x, w = generate_gemm_a16w16_inputs(M, N, K, c_dtype)
+        x, w = generate_gemm_a16w16_inputs(M, N, K, c_dtype, layout)
         # flops
         flops = 2.0 * M * N * K
         # memory transfer
@@ -119,6 +125,8 @@ def parse_args():
                         help="user-defined shape to benchmark")
     parser.add_argument("--metric", type=str, choices=["time", "throughput", "bandwidth"],
                         default="throughput", help="metric to plot")
+    parser.add_argument("--layout", type=str, choices=["TT", "TN", "NT", "NN"],
+                        default="TN", help="Layout of input and weight matrix")
     args = parser.parse_args()
     return args
 
0 commit comments