@@ -19,15 +19,15 @@ def write_annotations(output_folder: Path, strict: bool) -> None:
19
19
None
20
20
"""
21
21
generate_template = _generate_type_safe_template if strict else _generate_union_template
22
- for dimensions in _DIMENSION_TYPES :
22
+ for dimensions , filename in _DIMENSIONS_TO_FILENAME . items () :
23
23
contents = "\n " .join (_annotate_type (dimensions , type_name , strict ) for type_name in _DATA_TYPES )
24
24
all_types = "\n " .join (
25
25
_indent (f"{ _quote (full_type_name )} ," ) for full_type_name in _list_all_types (dimensions , strict )
26
26
)
27
27
filename = output_folder / _DIMENSIONS_TO_FILENAME [dimensions ]
28
28
print (f"Writing { filename } .." )
29
29
with open (filename , "w" ) as f :
30
- f .write (generate_template (dimensions , contents , all_types ))
30
+ f .write (generate_template (contents , all_types ))
31
31
32
32
33
33
_DATA_TYPES : Final [dict [str , str ]] = {
@@ -52,13 +52,6 @@ def write_annotations(output_folder: Path, strict: bool) -> None:
52
52
"Timedelta64" : "np.timedelta64" ,
53
53
}
54
54
55
- _DIMENSION_TYPES : Final [dict [int , str ]] = {
56
- 0 : "tuple[int, ...]" ,
57
- 1 : "tuple[int, ...]" ,
58
- 2 : "tuple[int, ...]" ,
59
- 3 : "tuple[int, ...]" ,
60
- }
61
-
62
55
_DIMENSIONS_TO_PREFIX : Final [dict [int , str ]] = {
63
56
0 : "NDArray" ,
64
57
1 : "1DArray" ,
@@ -103,11 +96,11 @@ def _union_type(dimension_type: str, dtype: str) -> str:
103
96
104
97
105
98
def _annotate_type (dimensions : int , type_name : str , strict : bool ) -> str :
106
- dimension_type = _DIMENSION_TYPES [dimensions ]
107
99
type_with_prefix = _type_name_with_prefix (dimensions , type_name , strict )
108
100
data_type = _DATA_TYPES [type_name ]
109
101
110
102
dtype = "Any" if data_type == "None" else data_type
103
+ dimension_type = "tuple[int, ...]"
111
104
T = _strict_type (dimension_type , dtype ) if strict else _union_type (dimension_type , dtype )
112
105
dim = dimensions if dimensions > 0 else None
113
106
annotation = f"""{ type_with_prefix } : TypeAlias = Annotated[
@@ -118,7 +111,7 @@ def _annotate_type(dimensions: int, type_name: str, strict: bool) -> str:
118
111
return _unindent (annotation )
119
112
120
113
121
- def _generate_type_safe_template (dimensions : int , contents : str , all_types : str ) -> str :
114
+ def _generate_type_safe_template (contents : str , all_types : str ) -> str :
122
115
template = f"""from typing import Annotated, Any, TypeAlias
123
116
124
117
import numpy as np
@@ -134,7 +127,7 @@ def _generate_type_safe_template(dimensions: int, contents: str, all_types: str)
134
127
return template
135
128
136
129
137
- def _generate_union_template (dimensions : int , contents : str , all_types : str ) -> str :
130
+ def _generate_union_template (contents : str , all_types : str ) -> str :
138
131
template = f"""from typing import Annotated, Any, TypeAlias, Union
139
132
140
133
import numpy as np
0 commit comments