11use std:: collections:: HashSet ;
22
33use pyo3:: {
4- PyResult , exceptions,
4+ IntoPyObjectExt , PyResult , exceptions,
55 prelude:: * ,
66 pybacked:: PyBackedStr ,
7- types:: { PyBytes , PyList , PyTuple } ,
7+ types:: { PyBytes , PyList } ,
88} ;
99use rustc_hash:: FxHashMap as HashMap ;
1010
@@ -37,11 +37,14 @@ impl CoreBPE {
3737 py : Python ,
3838 text : & str ,
3939 allowed_special : HashSet < PyBackedStr > ,
40- ) -> Vec < Rank > {
40+ ) -> PyResult < Vec < Rank > > {
4141 py. allow_threads ( || {
4242 let allowed_special: HashSet < & str > =
4343 allowed_special. iter ( ) . map ( |s| s. as_ref ( ) ) . collect ( ) ;
44- self . encode ( text, & allowed_special) . 0
44+ match self . encode ( text, & allowed_special) {
45+ Ok ( ( tokens, _) ) => Ok ( tokens) ,
46+ Err ( e) => Err ( PyErr :: new :: < exceptions:: PyValueError , _ > ( e. message ) ) ,
47+ }
4548 } )
4649 }
4750
@@ -50,14 +53,20 @@ impl CoreBPE {
5053 py : Python ,
5154 text : & str ,
5255 allowed_special : HashSet < PyBackedStr > ,
53- ) -> Py < PyAny > {
54- let tokens = py. allow_threads ( || {
56+ ) -> PyResult < Py < PyAny > > {
57+ let tokens_res = py. allow_threads ( || {
5558 let allowed_special: HashSet < & str > =
5659 allowed_special. iter ( ) . map ( |s| s. as_ref ( ) ) . collect ( ) ;
57- self . encode ( text, & allowed_special) . 0
60+ self . encode ( text, & allowed_special)
5861 } ) ;
62+
63+ let tokens = match tokens_res {
64+ Ok ( ( tokens, _) ) => tokens,
65+ Err ( e) => return Err ( PyErr :: new :: < exceptions:: PyValueError , _ > ( e. message ) ) ,
66+ } ;
67+
5968 let buffer = TiktokenBuffer { tokens } ;
60- buffer. into_py ( py)
69+ buffer. into_py_any ( py)
6170 }
6271
6372 fn _encode_bytes ( & self , py : Python , bytes : & [ u8 ] ) -> Vec < Rank > {
@@ -69,7 +78,8 @@ impl CoreBPE {
6978 // Unicode space, so we make our best guess at where we would have splits
7079 Err ( e) => {
7180 let text = unsafe { std:: str:: from_utf8_unchecked ( & bytes[ ..e. valid_up_to ( ) ] ) } ;
72- let ( tokens, last_piece_token_len) = self . encode ( text, & HashSet :: new ( ) ) ;
81+ let ( tokens, last_piece_token_len) =
82+ self . encode ( text, & HashSet :: new ( ) ) . unwrap ( ) ;
7383 let ( mut tokens, last_piece_token_len) =
7484 self . _increase_last_piece_token_len ( tokens, last_piece_token_len) ;
7585
@@ -110,19 +120,14 @@ impl CoreBPE {
110120 py : Python ,
111121 text : & str ,
112122 allowed_special : HashSet < PyBackedStr > ,
113- ) -> Py < PyTuple > {
114- let ( tokens, completions) = py. allow_threads ( || {
123+ ) -> PyResult < ( Vec < Rank > , Py < PyList > ) > {
124+ let ( tokens, completions) : ( Vec < Rank > , HashSet < Vec < Rank > > ) = py. allow_threads ( || {
115125 let allowed_special: HashSet < & str > =
116126 allowed_special. iter ( ) . map ( |s| s. as_ref ( ) ) . collect ( ) ;
117127 self . _encode_unstable_native ( text, & allowed_special)
118128 } ) ;
119- let py_completions = PyList :: new_bound (
120- py,
121- completions
122- . iter ( )
123- . map ( |seq| PyList :: new_bound ( py, & seq[ ..] ) ) ,
124- ) ;
125- ( tokens, py_completions) . into_py ( py)
129+ let py_completions = PyList :: new ( py, completions. into_iter ( ) ) ?;
130+ Ok ( ( tokens, py_completions. into ( ) ) )
126131 }
127132
128133 fn encode_single_token ( & self , piece : & [ u8 ] ) -> PyResult < Rank > {
@@ -151,17 +156,17 @@ impl CoreBPE {
151156 #[ pyo3( name = "decode_bytes" ) ]
152157 fn py_decode_bytes ( & self , py : Python , tokens : Vec < Rank > ) -> Result < Py < PyBytes > , PyErr > {
153158 match py. allow_threads ( || self . decode_bytes ( & tokens) ) {
154- Ok ( bytes) => Ok ( PyBytes :: new_bound ( py, & bytes) . into ( ) ) ,
159+ Ok ( bytes) => Ok ( PyBytes :: new ( py, & bytes) . into ( ) ) ,
155160 Err ( e) => Err ( pyo3:: exceptions:: PyKeyError :: new_err ( format ! ( "{}" , e) ) ) ,
156161 }
157162 }
158163
159164 fn decode_single_token_bytes ( & self , py : Python , token : Rank ) -> PyResult < Py < PyBytes > > {
160165 if let Some ( bytes) = self . decoder . get ( & token) {
161- return Ok ( PyBytes :: new_bound ( py, bytes) . into ( ) ) ;
166+ return Ok ( PyBytes :: new ( py, bytes) . into ( ) ) ;
162167 }
163168 if let Some ( bytes) = self . special_tokens_decoder . get ( & token) {
164- return Ok ( PyBytes :: new_bound ( py, bytes) . into ( ) ) ;
169+ return Ok ( PyBytes :: new ( py, bytes) . into ( ) ) ;
165170 }
166171 Err ( PyErr :: new :: < exceptions:: PyKeyError , _ > ( token. to_string ( ) ) )
167172 }
0 commit comments