diff --git a/README.md b/README.md
index 748578b6..661093fe 100644
--- a/README.md
+++ b/README.md
@@ -128,3 +128,19 @@ setup(
 
 Then simply `pip install ./my_tiktoken_extension` and you should be able to use your
 custom encodings! Make sure **not** to use an editable install.
+
+**Hosting your own encodings for enterprise usage.**
+
+For most use cases, the public OpenAI encodings are used by default and no changes are needed. However, organizations operating in an enterprise setting may have network configurations that require hosting the encoding files internally.
+
+To change the host that serves the encoding files used to populate the plugin modules, set the `ENCODINGS_HOST` environment variable. By default, files are fetched from the public OpenAI file server. Enterprises hosting their own encodings can see which encodings and files are supported, and how requests are routed, by viewing the [source directly](./tiktoken_ext/openai_public.py).
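+
+For example, a deployment might point `tiktoken` at an internal mirror. The host below is hypothetical, and the mirror is assumed to serve the same paths as the public server:
+
+```python
+import os
+os.environ["ENCODINGS_HOST"] = "https://artifacts.internal.example.com"  # hypothetical mirror
+
+import tiktoken
+enc = tiktoken.get_encoding("cl100k_base")  # files are now fetched from the internal host
+```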
diff --git a/tiktoken/load.py b/tiktoken/load.py
index cc0a6a6d..3c6d295b 100644
--- a/tiktoken/load.py
+++ b/tiktoken/load.py
@@ -32,7 +32,11 @@ def check_hash(data: bytes, expected_hash: str) -> bool:
     return actual_hash == expected_hash
 
 
-def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes:
+def read_file_cached(
+    blobpath: str,
+    expected_hash: Optional[str] = None,
+    is_self_hosting: bool = False,
+) -> bytes:
     user_specified_cache = True
     if "TIKTOKEN_CACHE_DIR" in os.environ:
         cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
@@ -52,9 +56,20 @@ def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes:
     if os.path.exists(cache_path):
         with open(cache_path, "rb") as f:
             data = f.read()
-        if expected_hash is None or check_hash(data, expected_hash):
+        if expected_hash is None:
             return data
+        if check_hash(data, expected_hash):
+            return data
+
+        if is_self_hosting:
+            raise ValueError(
+                f"Hash mismatch for data from {blobpath} (expected {expected_hash}). "
+                "This may indicate a change in the `tiktoken` encodings for this version. "
+                "Please update the hosted encodings, or remove/unset the `ENCODINGS_HOST` "
+                "environment variable to refresh the cache from the central host (`openaipublic`)."
+            )
+
 
         # the cached file does not match the hash, remove it and re-fetch
         try:
             os.remove(cache_path)
@@ -83,10 +98,8 @@ def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes:
 
 
 def data_gym_to_mergeable_bpe_ranks(
-    vocab_bpe_file: str,
-    encoder_json_file: str,
-    vocab_bpe_hash: Optional[str] = None,
-    encoder_json_hash: Optional[str] = None,
+    vocab_bpe_contents: str,
+    encoder_json_contents: str,
 ) -> dict[bytes, int]:
     # NB: do not add caching to this function
     rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]
@@ -101,7 +114,6 @@ def data_gym_to_mergeable_bpe_ranks(
     assert len(rank_to_intbyte) == 2**8
 
     # vocab_bpe contains the merges along with associated ranks
-    vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash).decode()
     bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]
 
     def decode_data_gym(value: str) -> bytes:
@@ -118,7 +130,7 @@ def decode_data_gym(value: str) -> bytes:
     # check that the encoder file matches the merges file
     # this sanity check is important since tiktoken assumes that ranks are ordered the same
     # as merge priority
-    encoder_json = json.loads(read_file_cached(encoder_json_file, encoder_json_hash))
+    encoder_json = json.loads(encoder_json_contents)
     encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
     # drop these two special tokens if present, since they're not mergeable bpe tokens
     encoder_json_loaded.pop(b"<|endoftext|>", None)
@@ -141,10 +153,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None:
 
 
 def load_tiktoken_bpe(
-    tiktoken_bpe_file: str, expected_hash: Optional[str] = None
+    contents: bytes,
 ) -> dict[bytes, int]:
     # NB: do not add caching to this function
-    contents = read_file_cached(tiktoken_bpe_file, expected_hash)
     return {
         base64.b64decode(token): int(rank)
         for token, rank in (line.split() for line in contents.splitlines() if line)
diff --git a/tiktoken_ext/openai_public.py b/tiktoken_ext/openai_public.py
index 330ecabb..9f557e0d 100644
--- a/tiktoken_ext/openai_public.py
+++ b/tiktoken_ext/openai_public.py
@@ -1,4 +1,5 @@
-from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe
+import os
+from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe, read_file_cached
 
 ENDOFTEXT = "<|endoftext|>"
 FIM_PREFIX = "<|fim_prefix|>"
@@ -6,13 +7,40 @@
 FIM_SUFFIX = "<|fim_suffix|>"
 ENDOFPROMPT = "<|endofprompt|>"
 
+if "ENCODINGS_HOST" in os.environ:
+    ENCODINGS_HOST = os.environ["ENCODINGS_HOST"]
+    IS_HOSTING_ENCODINGS = True
+else:
+    ENCODINGS_HOST = "https://openaipublic.blob.core.windows.net"
+    IS_HOSTING_ENCODINGS = False
+
+VOCAB_BPE_FILE = f"{ENCODINGS_HOST}/gpt-2/encodings/main/vocab.bpe"
+VOCAB_BPE_HASH = "1ce1664773c50f3e0cc8842619a93edc4624525b728b188a9e0be33b7726adc5"
+ENCODER_JSON_FILE = f"{ENCODINGS_HOST}/gpt-2/encodings/main/encoder.json"
+ENCODER_JSON_HASH = "196139668be63f3b5d6574427317ae82f612a97c5d1cdaf36ed2256dbf636783"
+R50K_BASE_FILE = f"{ENCODINGS_HOST}/encodings/r50k_base.tiktoken"
+R50K_BASE_HASH = "306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930"
+P50K_BASE_FILE = f"{ENCODINGS_HOST}/encodings/p50k_base.tiktoken"
+P50K_BASE_HASH = "94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069"
+CL100K_BASE_FILE = f"{ENCODINGS_HOST}/encodings/cl100k_base.tiktoken"
+CL100K_BASE_HASH = "223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7"
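+# NOTE: a self-hosted ENCODINGS_HOST is assumed to mirror the public host's path
+# layout, e.g. {ENCODINGS_HOST}/encodings/cl100k_base.tiktoken per the constants above.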
 
 def gpt2():
+    vocab_bpe_contents = read_file_cached(
+        VOCAB_BPE_FILE,
+        VOCAB_BPE_HASH,
+        IS_HOSTING_ENCODINGS,
+    ).decode()
+    encoder_json_contents = read_file_cached(
+        ENCODER_JSON_FILE,
+        ENCODER_JSON_HASH,
+        IS_HOSTING_ENCODINGS,
+    )
     mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
-        vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe",
-        encoder_json_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json",
-        vocab_bpe_hash="1ce1664773c50f3e0cc8842619a93edc4624525b728b188a9e0be33b7726adc5",
-        encoder_json_hash="196139668be63f3b5d6574427317ae82f612a97c5d1cdaf36ed2256dbf636783",
+        vocab_bpe_contents=vocab_bpe_contents,
+        encoder_json_contents=encoder_json_contents,
     )
     return {
         "name": "gpt2",
@@ -27,10 +55,8 @@ def gpt2():
 
 
 def r50k_base():
-    mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken",
-        expected_hash="306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930",
-    )
+    contents = read_file_cached(R50K_BASE_FILE, R50K_BASE_HASH, IS_HOSTING_ENCODINGS)
+    mergeable_ranks = load_tiktoken_bpe(contents)
     return {
         "name": "r50k_base",
         "explicit_n_vocab": 50257,
@@ -41,10 +67,8 @@ def r50k_base():
 
 
 def p50k_base():
-    mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
-        expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
-    )
+    contents = read_file_cached(P50K_BASE_FILE, P50K_BASE_HASH, IS_HOSTING_ENCODINGS)
+    mergeable_ranks = load_tiktoken_bpe(contents)
     return {
         "name": "p50k_base",
         "explicit_n_vocab": 50281,
@@ -55,10 +79,8 @@ def p50k_base():
 
 
 def p50k_edit():
-    mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
-        expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
-    )
+    contents = read_file_cached(P50K_BASE_FILE, P50K_BASE_HASH, IS_HOSTING_ENCODINGS)
+    mergeable_ranks = load_tiktoken_bpe(contents)
     special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
     return {
         "name": "p50k_edit",
@@ -69,10 +91,8 @@ def p50k_edit():
 
 
 def cl100k_base():
-    mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
-        expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
-    )
+    contents = read_file_cached(CL100K_BASE_FILE, CL100K_BASE_HASH, IS_HOSTING_ENCODINGS)
+    mergeable_ranks = load_tiktoken_bpe(contents)
     special_tokens = {
         ENDOFTEXT: 100257,
         FIM_PREFIX: 100258,