1414import argparse
1515import bisect
1616from pathlib import Path
17- from typing import Dict , List
1817
1918def main ():
2019 parser = argparse .ArgumentParser (
@@ -214,7 +213,7 @@ def get_op_pattern(base_op_name: str, get_backward: bool) -> tuple:
214213 else :
215214 return (base_op_name , f"{ base_op_name } " )
216215
217- def process_l1_loss (content : str , case_name : str , data : List , columns : List ):
216+ def process_l1_loss (content : str , case_name : str , data : list , columns : list ):
218217 shape_matches = list (re .finditer (r"(shape\s*[:=].*?)(?=\n\S|$)" , content ))
219218 shape_lines = [match .group (0 ) for match in shape_matches ]
220219 shape_positions = [match .start () for match in shape_matches ]
@@ -281,7 +280,7 @@ def process_l1_loss(content: str, case_name: str, data: List, columns: List):
281280
282281 data .append ([record .get (col , "" ) for col in columns ])
283282
284- def extract_times (content : str , pattern : str , get_backward : bool ) -> List :
283+ def extract_times (content : str , pattern : str , get_backward : bool ) -> list :
285284 lines = content .split ('\n ' )
286285 results = []
287286 for line in lines :
@@ -297,8 +296,8 @@ def extract_times(content: str, pattern: str, get_backward: bool) -> List:
297296
298297 return results
299298
300- def create_record (params : Dict , case_name : str , op_name : str ,
301- backward : str , time_us : float ) -> Dict :
299+ def create_record (params : dict , case_name : str , op_name : str ,
300+ backward : str , time_us : float ) -> dict :
302301 return {
303302 "P" : params .get ("p" , "" ),
304303 ** params ,
@@ -316,7 +315,7 @@ def convert_to_us(value: float, unit: str) -> float:
316315 return value * 1_000_000
317316 return value
318317
319- def extract_params (text : str ) -> Dict :
318+ def extract_params (text : str ) -> dict :
320319 params = {}
321320 pairs = re .split (r'[;]' , text .strip ())
322321
0 commit comments