From 4a632c164831295ac4a2aa1202712fa6aaa72342 Mon Sep 17 00:00:00 2001 From: Tim Nunamaker Date: Tue, 2 Apr 2024 15:46:22 -0500 Subject: [PATCH] Use a writeable directory for nltk --- selfie/__init__.py | 29 +++++-------- selfie/utils/filesystem.py | 86 ++++++++++++++++++++++++++++---------- 2 files changed, 76 insertions(+), 39 deletions(-) diff --git a/selfie/__init__.py b/selfie/__init__.py index d204c56..500369c 100644 --- a/selfie/__init__.py +++ b/selfie/__init__.py @@ -1,27 +1,20 @@ +import os + import selfie.logging import logging import warnings import colorlog -### Preemptive fix based on suggestion in https://github.com/BerriAI/litellm/issues/2607 -# import platform -# import os -# -# os_name = platform.system() -# -# if os_name == 'Darwin': # macOS -# cache_dir = os.path.expanduser('~/Library/Caches/TikToken') -# elif os_name == 'Windows': -# cache_dir = os.path.join(os.environ['APPDATA'], 'TikToken', 'Cache') -# else: # Assume Linux/Unix -# cache_dir = os.path.expanduser('~/TikToken/Cache') -# -# # LiteLLM writes to a read-only directory in the built application bundle, try to override it -# # Source: https://github.com/BerriAI/litellm/pull/1947, with the latest code here: https://github.com/BerriAI/litellm/blob/main/litellm/utils.py -# os.environ['TIKTOKEN_CACHE_DIR'] = cache_dir -# -# # Now we can safely import litellm +from selfie.utils.filesystem import get_nltk_dir, get_tiktoken_dir + +# Override the data dir for nltk as LlamaIndex chooses a write-protected directory (https://github.com/run-llama/llama_index/blob/v0.10.26/llama-index-core/llama_index/core/utils.py#L47) +# https://github.com/nltk/nltk/blob/3.8.1/web/data.rst +os.environ["NLTK_DATA"] = get_nltk_dir("Selfie") +os.environ["TIKTOKEN_CACHE_DIR"] = get_tiktoken_dir("Selfie") +# LiteLLM doesn't (yet?) 
respect the TIKTOKEN_CACHE_DIR environment variable +# Source: https://github.com/BerriAI/litellm/pull/1947, with the latest code here: https://github.com/BerriAI/litellm/blob/main/litellm/utils.py +# Now we can safely import litellm import litellm # Suppress specific warnings diff --git a/selfie/utils/filesystem.py b/selfie/utils/filesystem.py index 17ec811..96b98c7 100644 --- a/selfie/utils/filesystem.py +++ b/selfie/utils/filesystem.py @@ -2,36 +2,81 @@ import platform -def get_app_dir(app_name, dir_name, roaming=True, log_dir=False): +def get_system_path(app_name, dir_name, path_type='data'): + """ + Generates paths for app data, caches, and logs based on the operating system. + + Args: + app_name (str): The application's name, part of the path. + dir_name (str): The specific directory name for the data. + path_type (str): The type of path ('data', 'cache', 'logs'). + + Returns: + str: The constructed path. + """ os_name = platform.system() - if os_name == 'Darwin': - home = os.path.expanduser('~') - if log_dir: - return os.path.join(home, 'Library', 'Logs', app_name, dir_name) - return os.path.join(home, 'Library', 'Application Support', app_name, dir_name) - elif os_name == 'Windows': - if roaming: - root = os.environ.get('APPDATA') - else: - root = os.environ.get('LOCALAPPDATA') - if root is None: - raise OSError("Unable to determine application data directory") - return os.path.join(root, app_name, dir_name) - else: - home = os.path.expanduser('~') - return os.path.join(home, '.' 
+ app_name, dir_name) + paths_config = { + 'Darwin': { + 'base': os.path.expanduser('~'), + 'sub': { + 'data': ['Library', 'Application Support'], + 'cache': ['Library', 'Caches'], + 'logs': ['Library', 'Logs'], + } + }, + 'Windows': { + 'base': { + 'data': os.environ.get('APPDATA'), + 'cache': os.environ.get('LOCALAPPDATA'), + 'logs': os.environ.get('LOCALAPPDATA'), + }, + 'sub': { + 'data': [], + 'cache': ['Cache'], + 'logs': ['Logs'], + } + }, + 'Linux': { + 'base': os.path.expanduser('~'), + 'sub': { + 'data': ['.' + app_name], + 'cache': ['.' + app_name, 'cache'], + 'logs': [app_name, 'logs'], + } + } + } + + config = paths_config.get(os_name, paths_config['Linux']) # Default to Linux for unknown OS + base_path = config['base'][path_type] if isinstance(config['base'], dict) else config['base'] + sub_path = config['sub'][path_type] + [app_name, dir_name] if path_type in ['data', 'logs'] else config['sub'][path_type] + [dir_name] + + if base_path is None: + raise OSError(f"Unable to determine base path for {path_type} on {os_name}.") + + constructed_path = os.path.join(base_path, *sub_path) + normalized_path = os.path.normpath(constructed_path) + + return normalized_path def ensure_dir_exists(dir_path): os.makedirs(dir_path, exist_ok=True) +def get_nltk_dir(app_name): + return get_system_path(app_name, "nltk_data", path_type='cache') + + +def get_tiktoken_dir(app_name): + return get_system_path(app_name, "tiktoken_cache", path_type='cache') + + def get_data_dir(app_name): - return get_app_dir(app_name, 'data', roaming=True) + return get_system_path(app_name, 'data') def get_log_dir(app_name): - return get_app_dir(app_name, '', log_dir=True) + return get_system_path(app_name, '', path_type='logs') def get_data_path(app_name, file_name): @@ -45,8 +90,7 @@ def get_log_path(app_name, file_name): ensure_dir_exists(log_dir) return os.path.join(log_dir, file_name) + def resolve_path(path): """Expand user directory (~), resolve to absolute path, and follow symlinks.""" 
return os.path.realpath(os.path.abspath(os.path.expanduser(path))) - -