6
6
7
7
import logging
8
8
import os
9
+ import sys
9
10
from collections import OrderedDict
10
11
from typing import Any , Literal , NamedTuple , TypeVar , Union
11
12
15
16
from .quants import quant_shape_to_byte_shape
16
17
17
18
if __name__ == "__main__" :
18
- import sys
19
19
from pathlib import Path
20
20
21
21
# Allow running file in package as a script.
28
28
GGUF_VERSION ,
29
29
GGMLQuantizationType ,
30
30
GGUFValueType ,
31
+ GGUFEndian ,
31
32
)
32
33
33
34
logger = logging .getLogger (__name__ )
@@ -53,6 +54,48 @@ class ReaderField(NamedTuple):
53
54
54
55
types : list [GGUFValueType ] = []
55
56
57
def contents(self, index_or_slice: int | slice = slice(None)) -> Any:
    """Decode this field's value into plain Python objects.

    Args:
        index_or_slice: Only used for ARRAY fields. A single index selects
            one element; a slice selects a sub-list. Defaults to the whole
            array. Any non-slice value is treated as a single index, so
            integer-like objects (anything ``self.data`` accepts as a
            subscript) work as well as plain ``int``.

    Returns:
        * ``None`` when the field has no recorded types.
        * For an ARRAY of STRING: one ``str`` (index) or a ``list[str]``
          (slice), each decoded as UTF-8 from the raw bytes of its part.
        * For any other ARRAY: the first value of the selected part
          (index), or a flat list of every value of the selected parts
          (slice).
        * For a STRING field: the last part decoded as UTF-8.
        * Otherwise: the first value of the last part.
    """
    if not self.types:
        return None

    def to_string(part: Any) -> str:
        # Parts hold raw bytes; strings are stored as UTF-8.
        return str(part.tobytes(), encoding='utf-8')

    main_type = self.types[0]

    if main_type == GGUFValueType.ARRAY:
        sub_type = self.types[-1]

        # Decide on "slice vs. single index" by testing for slice rather
        # than for int: an integer-like index that is not a Python int
        # (e.g. a NumPy integer scalar) would otherwise be routed to the
        # list branch and fail when iterating over a single index.
        wants_list = isinstance(index_or_slice, slice)
        selected = self.data[index_or_slice]

        if sub_type == GGUFValueType.STRING:
            if wants_list:
                return [to_string(self.parts[idx]) for idx in selected]  # type: ignore
            return to_string(self.parts[selected])  # type: ignore

        # FIXME: When/if _get_field_parts() supports multi-dimensional arrays, this must do so too
        #
        # NOTE: the index/slice is applied to self.data (i.e. per part)
        # before flattening. Slicing the flattened values instead would
        # only be equivalent if every part held exactly one value.
        if wants_list:
            return [pv for idx in selected for pv in self.parts[idx].tolist()]  # type: ignore
        return self.parts[selected].tolist()[0]  # type: ignore

    if main_type == GGUFValueType.STRING:
        return to_string(self.parts[-1])

    return self.parts[-1].tolist()[0]
98
+
56
99
57
100
class ReaderTensor (NamedTuple ):
58
101
name : str
@@ -101,10 +144,19 @@ def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] =
101
144
# If we get 0 here that means it's (probably) a GGUF file created for
102
145
# the opposite byte order of the machine this script is running on.
103
146
self .byte_order = 'S'
104
- temp_version = temp_version .newbyteorder (self .byte_order )
147
+ temp_version = temp_version .view ( temp_version . dtype . newbyteorder (self .byte_order ) )
105
148
version = temp_version [0 ]
106
149
if version not in READER_SUPPORTED_VERSIONS :
107
150
raise ValueError (f'Sorry, file appears to be version { version } which we cannot handle' )
151
+ if sys .byteorder == "little" :
152
+ # Host is little endian
153
+ host_endian = GGUFEndian .LITTLE
154
+ swapped_endian = GGUFEndian .BIG
155
+ else :
156
+ # Sorry PDP or other weird systems that don't use BE or LE.
157
+ host_endian = GGUFEndian .BIG
158
+ swapped_endian = GGUFEndian .LITTLE
159
+ self .endianess = swapped_endian if self .byte_order == "S" else host_endian
108
160
self .fields : OrderedDict [str , ReaderField ] = OrderedDict ()
109
161
self .tensors : list [ReaderTensor ] = []
110
162
offs += self ._push_field (ReaderField (offs , 'GGUF.version' , [temp_version ], [0 ], [GGUFValueType .UINT32 ]))
@@ -146,11 +198,7 @@ def _get(
146
198
itemsize = int (np .empty ([], dtype = dtype ).itemsize )
147
199
end_offs = offset + itemsize * count
148
200
arr = self .data [offset :end_offs ].view (dtype = dtype )[:count ]
149
- if override_order is not None :
150
- return arr .view (arr .dtype .newbyteorder (override_order ))
151
- if self .byte_order == 'S' :
152
- return arr .view (arr .dtype .newbyteorder (self .byte_order ))
153
- return arr
201
+ return arr .view (arr .dtype .newbyteorder (self .byte_order if override_order is None else override_order ))
154
202
155
203
def _push_field (self , field : ReaderField , skip_sum : bool = False ) -> int :
156
204
if field .name in self .fields :
@@ -192,6 +240,7 @@ def _get_field_parts(
192
240
offs += int (alen .nbytes )
193
241
aparts : list [npt .NDArray [Any ]] = [raw_itype , alen ]
194
242
data_idxs : list [int ] = []
243
+ # FIXME: Handle multi-dimensional arrays properly instead of flattening
195
244
for idx in range (alen [0 ]):
196
245
curr_size , curr_parts , curr_idxs , curr_types = self ._get_field_parts (offs , raw_itype [0 ])
197
246
if idx == 0 :
0 commit comments