chore: consolidate environment variables within one file (#882)
AlpinDale authored Dec 12, 2024
1 parent ce6e3d6 commit 9019008
Showing 36 changed files with 543 additions and 132 deletions.
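
The new consolidated aphrodite/envs.py itself is not among the hunks shown here (the remaining files did not load in this view). As a rough sketch of how such a consolidated module is typically structured, with names and defaults inferred from the call sites in this commit rather than taken from the actual file, it maps each variable to a lazy parser and resolves attribute access through a module-level __getattr__ (PEP 562):

    # Hypothetical sketch of a consolidated envs module; not the actual
    # aphrodite/envs.py from this commit.
    import os
    from typing import Any, Callable, Dict

    # Zero-argument callables so the environment is read lazily, at
    # attribute-access time, instead of once at import time.
    environment_variables: Dict[str, Callable[[], Any]] = {
        "APHRODITE_ASSETS_CACHE": lambda: os.path.expanduser(
            os.getenv("APHRODITE_ASSETS_CACHE",
                      "~/.cache/aphrodite/assets")),  # default path assumed
        "APHRODITE_IMAGE_FETCH_TIMEOUT": lambda: int(
            os.getenv("APHRODITE_IMAGE_FETCH_TIMEOUT", "5")),
        "APHRODITE_USE_TRITON_FLASH_ATTN": lambda: (
            os.getenv("APHRODITE_USE_TRITON_FLASH_ATTN", "True").lower()
            in ("true", "1")),
        "CUDA_VISIBLE_DEVICES": lambda: os.environ.get(
            "CUDA_VISIBLE_DEVICES", None),
    }

    def __getattr__(name: str) -> Any:
        # `envs.APHRODITE_IMAGE_FETCH_TIMEOUT` becomes a lazy lookup here.
        if name in environment_variables:
            return environment_variables[name]()
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    def __dir__():
        return list(environment_variables)
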
10 changes: 3 additions & 7 deletions aphrodite/assets/base.py
@@ -5,6 +5,7 @@
from pathlib import Path
from typing import Optional

from aphrodite import envs
from aphrodite.connections import global_http_connection


@@ -15,13 +16,8 @@ def get_default_cache_root():
)

vLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"
APHRODITE_ASSETS_CACHE = os.path.expanduser(
os.getenv(
"APHRODITE_ASSETS_CACHE",
os.path.join(get_default_cache_root(), "aphrodite", "assets"),
))
APHRODITE_IMAGE_FETCH_TIMEOUT = int(os.getenv("APHRODITE_IMAGE_FETCH_TIMEOUT",
5))
APHRODITE_ASSETS_CACHE = envs.APHRODITE_ASSETS_CACHE
APHRODITE_IMAGE_FETCH_TIMEOUT = envs.APHRODITE_IMAGE_FETCH_TIMEOUT

def get_cache_dir() -> Path:
"""Get the path to the cache for storing downloaded assets."""
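
A side note on the pattern used in this file: binding module-level constants such as APHRODITE_IMAGE_FETCH_TIMEOUT = envs.APHRODITE_IMAGE_FETCH_TIMEOUT captures the value once, when aphrodite.assets.base is first imported, whereas reading envs.<NAME> at the point of use would re-evaluate the variable on each access (assuming the consolidated module resolves attributes lazily, which this excerpt does not show). A minimal illustration of the difference, under that assumption:

    import os

    from aphrodite import envs
    import aphrodite.assets.base as assets_base  # snapshot happens here

    os.environ["APHRODITE_IMAGE_FETCH_TIMEOUT"] = "30"

    # Re-reads the environment on access (assumed lazy envs module):
    print(envs.APHRODITE_IMAGE_FETCH_TIMEOUT)         # 30
    # Captured at import time of aphrodite.assets.base, so unchanged:
    print(assets_base.APHRODITE_IMAGE_FETCH_TIMEOUT)  # 5, the default,
    # assuming the variable was unset when that module was first imported
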
6 changes: 2 additions & 4 deletions aphrodite/attention/backends/rocm_flash_attn.py
@@ -1,11 +1,11 @@
"""Attention layer ROCm GPUs."""
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
from loguru import logger

from aphrodite import envs
from aphrodite.attention.backends.abstract import (AttentionBackend,
AttentionImpl,
AttentionMetadata,
@@ -280,9 +280,7 @@ def __init__(

self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = (os.environ.get(
"APHRODITE_USE_TRITON_FLASH_ATTN", "True").lower()
in ("true", "1"))
self.use_triton_flash_attn = envs.APHRODITE_USE_TRITON_FLASH_ATTN
if self.use_triton_flash_attn:
from aphrodite.attention.ops.triton_flash_attn import ( # noqa: F401
triton_attention)
3 changes: 2 additions & 1 deletion aphrodite/attention/selector.py
@@ -7,12 +7,13 @@
import torch
from loguru import logger

from aphrodite import envs
from aphrodite.attention.backends.abstract import AttentionBackend
from aphrodite.common.utils import (STR_BACKEND_ENV_VAR, is_cpu, is_hip,
is_openvino, is_xpu)
from aphrodite.platforms import current_platform

APHRODITE_ATTENTION_BACKEND = os.getenv("APHRODITE_ATTENTION_BACKEND", None)
APHRODITE_ATTENTION_BACKEND = envs.APHRODITE_ATTENTION_BACKEND


class _Backend(enum.Enum):
50 changes: 34 additions & 16 deletions aphrodite/common/config.py
@@ -9,6 +9,7 @@
from loguru import logger
from transformers import PretrainedConfig

from aphrodite import envs
from aphrodite.common.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
cuda_device_count_stateless,
get_cpu_memory, is_cpu, is_hip, is_neuron,
@@ -30,8 +31,7 @@
BaseTokenizerGroup)

# If true, will load models from ModelScope instead of Hugging Face Hub.
APHRODITE_USE_MODELSCOPE = os.environ.get("APHRODITE_USE_MODELSCOPE",
"False").lower() == "true"
APHRODITE_USE_MODELSCOPE = envs.APHRODITE_USE_MODELSCOPE

_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768

@@ -1820,21 +1820,39 @@ def _get_and_verify_max_len(
"original_max_position_embeddings"]
derived_max_model_len *= scaling_factor

# If the user specified a max length, make sure it is smaller than the
# derived length from the HF model config.
if max_model_len is None:
max_model_len = derived_max_model_len
elif max_model_len > derived_max_model_len and rope_scaling_arg is None:
raise ValueError(
f"User-specified max_model_len {max_model_len} is higher than "
f"the original {derived_max_model_len}. "
"Please provide a rope_scaling dict to scale the model.")
elif max_model_len > derived_max_model_len and rope_scaling_arg is not None:
# hope this works
logger.warning(
f"User-specified max_model_len {max_model_len} is higher than "
f"the original {derived_max_model_len}. "
"Attempting to use RoPE scaling with the provided rope_scaling "
"dict.")
derived_max_model_len = max_model_len
max_model_len = int(derived_max_model_len)
elif max_model_len > derived_max_model_len:
# Some models might have a separate key for specifying model_max_length
# that will be bigger than derived_max_model_len. We compare user input
# with model_max_length and allow this override when it's smaller.
model_max_length = getattr(hf_config, "model_max_length", None)
if envs.APHRODITE_DYNAMIC_ROPE_SCALING:
scaling_factor = max_model_len / derived_max_model_len
hf_config.rope_scaling = {"factor": scaling_factor,
"type": "dynamic"}
logger.info(
"Using dynamic RoPE scaling to extend the model's max context "
f"length from {derived_max_model_len} to {max_model_len}.")
derived_max_model_len = max_model_len
elif model_max_length is not None and max_model_len <= model_max_length:
if disable_sliding_window:
# TODO: Find a model that has model_max_length
# with sliding window to see if this case should be allowed.
raise NotImplementedError(
"Disabling sliding window is not supported for models "
"model_max_length in the config. Please raise an issue "
"so we can investigate.")
else:
raise ValueError(
f"User-specified max_model_len ({max_model_len}) is greater "
f"than the derived max_model_len ({max_len_key}="
f"{derived_max_model_len} or model_max_length="
f"{model_max_length} in model's config.json). To allow "
"greater lengths, please set the env var "
"APHRODITE_DYNAMIC_ROPE_SCALING=1")

return int(max_model_len)

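
As a concrete illustration of the new APHRODITE_DYNAMIC_ROPE_SCALING branch above: for a model whose config derives a 4096-token limit, requesting max_model_len=16384 patches the HF config roughly as follows (a worked sketch of the arithmetic in the hunk, not extra code from the commit):

    derived_max_model_len = 4096   # derived from the model's config.json
    max_model_len = 16384          # user-requested context length

    scaling_factor = max_model_len / derived_max_model_len   # 4.0
    rope_scaling = {"factor": scaling_factor, "type": "dynamic"}
    # hf_config.rope_scaling becomes {"factor": 4.0, "type": "dynamic"}, the
    # derived limit is raised to 16384, and the request is accepted instead
    # of raising the ValueError shown above.
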
5 changes: 3 additions & 2 deletions aphrodite/common/logger.py
@@ -15,11 +15,12 @@
from rich.progress import (BarColumn, MofNCompleteColumn, Progress,
TaskProgressColumn, TextColumn, TimeRemainingColumn)

from aphrodite import envs

RICH_CONSOLE = Console()
LOG_LEVEL = os.getenv("APHRODITE_LOG_LEVEL", "INFO").upper()

APHRODITE_CONFIGURE_LOGGING = int(os.getenv("APHRODITE_CONFIGURE_LOGGING",
"1"))
APHRODITE_CONFIGURE_LOGGING = envs.APHRODITE_CONFIGURE_LOGGING


def unwrap(wrapped, default=None):
6 changes: 3 additions & 3 deletions aphrodite/common/sampling_params.py
@@ -1,6 +1,5 @@
"""Sampling parameters for text generation."""
import copy
import os
from enum import IntEnum
from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Set, Union
@@ -10,11 +9,12 @@
from loguru import logger
from typing_extensions import Annotated

from aphrodite import envs

_SAMPLING_EPS = 1e-5
_MAX_TEMP = 1e-2

APHRODITE_NO_DEPRECATION_WARNING = bool(
int(os.environ.get("APHRODITE_NO_DEPRECATION_WARNING", "0")))
APHRODITE_NO_DEPRECATION_WARNING = envs.APHRODITE_NO_DEPRECATION_WARNING


class SamplingType(IntEnum):
19 changes: 8 additions & 11 deletions aphrodite/common/utils.py
@@ -31,6 +31,7 @@
SpinnerColumn, TextColumn, TimeElapsedColumn)
from typing_extensions import ParamSpec, TypeIs, assert_never

from aphrodite import envs
from aphrodite.common.logger import enable_trace_function_call
from aphrodite.distributed import get_tensor_model_parallel_rank

@@ -382,8 +383,7 @@ def get_aphrodite_instance_id():
Instance id represents an instance of the Aphrodite. All processes in the
same instance should have the same instance id.
"""
return os.environ.get("APHRODITE_INSTANCE_ID",
f"aphrodite-instance-{random_uuid()}")
return envs.APHRODITE_INSTANCE_ID or f"aphrodite-instance-{random_uuid()}"


@lru_cache(maxsize=None)
@@ -520,18 +520,15 @@ def get_distributed_init_method(ip: str, port: int) -> str:

def get_open_zmq_ipc_path() -> str:
if not in_windows():
APHRODITE_RPC_BASE_PATH = os.getenv("APHRODITE_RPC_BASE_PATH",
tempfile.gettempdir())
base_rpc_path = APHRODITE_RPC_BASE_PATH
base_rpc_path = envs.APHRODITE_RPC_BASE_PATH
return f"ipc://{base_rpc_path}/{uuid4()}"
else:
# windows doesn't support ipc://
# use tcp:// instead
return f"tcp://127.0.0.1:{get_open_port()}"

def get_open_port(port: Optional[int] = None) -> int:
port = int(os.getenv("APHRODITE_PORT", 0)
) if "APHRODITE_PORT" in os.environ else None
port = envs.APHRODITE_PORT
if port is not None:
while True:
try:
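
The removed lines only honored APHRODITE_PORT when the variable was actually present, yielding None otherwise, so the `if port is not None` check above keeps its meaning only if envs.APHRODITE_PORT preserves that contract. A sketch of the assumed accessor (the real definition lives in the new envs module, which is not shown here):

    import os
    from typing import Optional

    def _aphrodite_port() -> Optional[int]:
        # Unset means "no fixed port requested": get_open_port() then falls
        # back to choosing a port itself.
        if "APHRODITE_PORT" in os.environ:
            return int(os.environ["APHRODITE_PORT"])
        return None
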
@@ -948,7 +945,7 @@ def find_library(lib_name: str) -> str:
# libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
# `LD_LIBRARY_PATH` searches the library in the user-defined paths
env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
env_ld_library_path = envs.LD_LIBRARY_PATH
if not locs and env_ld_library_path:
locs = [
os.path.join(dir, lib_name)
@@ -967,7 +964,7 @@ def find_nccl_library() -> str:
After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
found by `ctypes` automatically.
"""
so_file = os.environ.get("APHRODITE_NCCL_SO_PATH", "")
so_file = envs.APHRODITE_NCCL_SO_PATH

# manually load the nccl library
if so_file:
@@ -985,7 +982,7 @@ def find_nccl_library() -> str:


def enable_trace_function_call_for_thread() -> None:
if int(os.getenv("APHRODITE_TRACE_FUNCTION", "0")):
if envs.APHRODITE_TRACE_FUNCTION:
tmp_dir = tempfile.gettempdir()
filename = (f"APHRODITE_TRACE_FUNCTION_for_process_{os.getpid()}"
f"_thread_{threading.get_ident()}_"
@@ -1074,7 +1071,7 @@ def cuda_device_count_stateless() -> int:
# This can be removed and simply replaced with torch.cuda.get_device_count
# after https://github.com/pytorch/pytorch/pull/122815 is released.

return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES"))
return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)


#From: https://stackoverflow.com/a/4104188/2749989
@@ -1,4 +1,3 @@
import os
from contextlib import contextmanager
from typing import Any, List, Optional, Union

@@ -8,6 +7,7 @@
from torch.distributed import ProcessGroup

from aphrodite import _custom_ops as ops
from aphrodite import envs
from aphrodite.common.utils import cuda_device_count_stateless
from aphrodite.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check)
@@ -95,7 +95,7 @@ def __init__(self,
assert isinstance(device, torch.device)
self.device = device

cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
@@ -11,6 +11,7 @@
import torch.multiprocessing as mp
from loguru import logger

from aphrodite import envs
from aphrodite.common.utils import (cuda_device_count_stateless,
update_environment_variables)
from aphrodite.distributed.device_communicators.cuda_wrapper import (
@@ -124,7 +125,7 @@ def can_actually_p2p(
processes for testing all pairs of GPUs in batch. The trick is to reset
the device after each test (which is not available in PyTorch).
""" # noqa
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
# pass the CUDA_VISIBLE_DEVICES to the child process
# to make sure they see the same set of GPUs

@@ -183,13 +184,13 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
is_distributed = dist.is_initialized()

num_dev = cuda_device_count_stateless()
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices is None:
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
APHRODITE_CONFIG_ROOT = os.getenv("APHRODITE_CONFIG_ROOT", "~/.config")
path = os.path.expanduser(
f"{APHRODITE_CONFIG_ROOT}/aphrodite/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
)

path = os.path.join(
envs.APHRODITE_CACHE_ROOT,
f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
os.makedirs(os.path.dirname(path), exist_ok=True)
from aphrodite.distributed.parallel_state import get_world_group
if ((not is_distributed or get_world_group().local_rank == 0)
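
This hunk also moves the P2P-access cache file from a hand-built path under APHRODITE_CONFIG_ROOT (previously defaulting to ~/.config) to one rooted at envs.APHRODITE_CACHE_ROOT. Assuming the cache root resolves to something like ~/.cache/aphrodite (an assumption; the actual default is defined in the envs module not shown here), the resulting location looks like this:

    import os

    # Illustrative only: ~/.cache/aphrodite as the cache root is assumed.
    cache_root = os.path.expanduser("~/.cache/aphrodite")
    cuda_visible_devices = "0,1"
    path = os.path.join(
        cache_root,
        f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
    print(path)  # e.g. /home/<user>/.cache/aphrodite/gpu_p2p_access_cache_for_0,1.json
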
6 changes: 3 additions & 3 deletions aphrodite/distributed/device_communicators/shm_broadcast.py
@@ -1,4 +1,3 @@
import os
import pickle
import time
from contextlib import contextmanager
@@ -13,10 +12,11 @@
from torch.distributed import ProcessGroup
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore

from aphrodite import envs
from aphrodite.common.utils import get_ip, get_open_port

APHRODITE_RINGBUFFER_WARNING_INTERVAL = os.getenv(
"APHRODITE_RINGBUFFER_WARNING_INTERVAL", 60)
APHRODITE_RINGBUFFER_WARNING_INTERVAL = (
envs.APHRODITE_RINGBUFFER_WARNING_INTERVAL)

# time to wait if the queue is full or empty
# if we sleep for too short, it will consume too much CPU
5 changes: 3 additions & 2 deletions aphrodite/distributed/parallel_state.py
@@ -21,7 +21,6 @@
steps.
"""
import contextlib
import os
import pickle
import sys
from collections import namedtuple
@@ -36,6 +35,8 @@
from loguru import logger
from torch.distributed import Backend, ProcessGroup

from aphrodite import envs


@dataclass
class GraphCaptureContext:
@@ -866,7 +867,7 @@ def init_distributed_environment(
# local rank not set, this usually happens in single-node
# setting, where we can use rank as local rank
if distributed_init_method == "env://":
local_rank = os.getenv("LOCAL_RANK", rank)
local_rank = envs.LOCAL_RANK
else:
local_rank = rank
global _WORLD
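
One behavioral nuance in this hunk: the removed line fell back to `rank` when LOCAL_RANK was unset (and returned a string when it was set), while the new code takes whatever envs.LOCAL_RANK yields. Assuming the consolidated accessor parses the variable as an int with a default of 0 (an assumption, since envs.py is not shown), it would look roughly like this:

    import os

    def _local_rank() -> int:
        # torchrun exports LOCAL_RANK for the "env://" init method, so the
        # default only matters when the launcher did not set it.
        return int(os.getenv("LOCAL_RANK", "0"))
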
5 changes: 3 additions & 2 deletions aphrodite/distributed/utils.py
@@ -3,12 +3,13 @@
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import os
from typing import Sequence, Tuple

import torch

APHRODITE_PP_LAYER_PARTITION = os.getenv("APHRODITE_PP_LAYER_PARTITION", None)
from aphrodite import envs

APHRODITE_PP_LAYER_PARTITION = envs.APHRODITE_PP_LAYER_PARTITION


def ensure_divisibility(numerator, denominator):
3 changes: 2 additions & 1 deletion aphrodite/endpoints/openai/api_server.py
@@ -20,6 +20,7 @@
from loguru import logger
from starlette.routing import Mount

from aphrodite import envs
from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
@@ -635,7 +636,7 @@ async def validation_exception_handler(_, exc):
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)

if token := os.environ.get("APHRODITE_API_KEY") or args.api_keys:
if token := envs.APHRODITE_API_KEY or args.api_keys:
admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key

if admin_key is None: