@@ -68,23 +68,34 @@ impl CoreBPE {
     fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> {
         py.allow_threads(|| {
             match std::str::from_utf8(bytes) {
+                // Straightforward case
                 Ok(text) => self.encode_ordinary(text),
+                // Oops, don't actually have UTF-8. But we need to do the regex splitting in
+                // Unicode space, so we make our best guess at where we would have splits
                 Err(e) => {
                     let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
                     let (tokens, last_piece_token_len) = self.encode(text, &HashSet::new());
                     let (mut tokens, last_piece_token_len) =
                         self._increase_last_piece_token_len(tokens, last_piece_token_len);
+
+                    let mut unstable_bytes;
                     if !tokens.is_empty() && last_piece_token_len > 0 {
                         // Lop off the tokens from the last piece and run BPE on the remaining bytes
-                        // Somewhat niche, but this may not be correct if we'd have had a regex
-                        // split between the valid UTF-8 and the invalid bytes, which is why this
-                        // method is private
-                        let mut unstable_bytes = self
+                        // This likely matches what models see better, e.g. if you assume we're
+                        // dealing with truncated UTF-8 bytes.
+                        // Niche, but note this may not be correct if we'd have had a regex
+                        // split between the valid UTF-8 and the invalid bytes.
+                        unstable_bytes = self
                             .decode_bytes(&tokens[tokens.len() - last_piece_token_len..])
                             .unwrap();
                         unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);
 
                         tokens.truncate(tokens.len() - last_piece_token_len);
+                    } else {
+                        unstable_bytes = bytes[e.valid_up_to()..].to_vec();
+                    }
+
+                    if !unstable_bytes.is_empty() {
                         match self.encoder.get(&unstable_bytes) {
                             Some(token) => tokens.push(*token),
                             None => {
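For context, here is a minimal standalone sketch of the valid_up_to() split that the Err arm above relies on. The split_at_valid_utf8 helper is hypothetical and not part of CoreBPE; the real code then runs regex splitting and BPE on the valid prefix and a byte-level encoder lookup on the unstable tail, but the prefix/tail separation itself works like this:

// Hypothetical helper: split raw bytes into the longest valid UTF-8 prefix
// and the remaining "unstable" tail (e.g. a truncated multi-byte sequence).
fn split_at_valid_utf8(bytes: &[u8]) -> (&str, &[u8]) {
    match std::str::from_utf8(bytes) {
        // Straightforward case: everything is valid UTF-8, so the tail is empty
        Ok(text) => (text, &bytes[bytes.len()..]),
        Err(e) => {
            let valid = e.valid_up_to();
            // SAFETY: from_utf8 just verified that bytes[..valid] is valid UTF-8
            let text = unsafe { std::str::from_utf8_unchecked(&bytes[..valid]) };
            (text, &bytes[valid..])
        }
    }
}

fn main() {
    // "abc" followed by the first two bytes of a truncated three-byte
    // sequence for the euro sign (E2 82 AC)
    let bytes = b"abc\xe2\x82";
    let (text, unstable) = split_at_valid_utf8(bytes);
    assert_eq!(text, "abc");
    assert_eq!(unstable, &[0xe2, 0x82]);
    println!("valid prefix: {text:?}, unstable tail: {unstable:?}");
}

The unsafe block mirrors the one in the diff: from_utf8 has already validated bytes[..valid_up_to()], so from_utf8_unchecked simply avoids re-running that check.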