Fixes imports order
Signed-off-by: Flavia Beo <[email protected]>
flaviabeo committed Oct 22, 2024
1 parent 80a84bb commit 0514119
Showing 4 changed files with 8 additions and 8 deletions.
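For context, the reordering below follows the usual standard-library / third-party / first-party grouping, alphabetized within each group (the layout a formatter such as isort produces; the exact tool is not named in this commit). A minimal sketch of that layout, using imports drawn from the files touched here:

from dataclasses import dataclass   # standard library first, alphabetized
from enum import IntEnum

import torch                        # third-party packages next
from transformers import PretrainedConfig

import vllm.envs as envs            # first-party (vllm) modules last
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import PoolingConfig
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS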
tests/test_config.py (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
 import pytest

 from vllm.config import ModelConfig
-from vllm.model_executor.layers.pooler import (PoolingConfig, PoolingType)
+from vllm.model_executor.layers.pooler import PoolingConfig, PoolingType


 @pytest.mark.parametrize(("model_id", "expected_task"), [
vllm/config.py (8 changes: 4 additions & 4 deletions)
@@ -5,19 +5,19 @@
                     Mapping, Optional, Set, Tuple, Type, Union)

 import torch
+from transformers import PretrainedConfig

 import vllm.envs as envs
-from transformers import PretrainedConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.layers.pooler import PoolingConfig
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (ConfigFormat, get_config,
                                             get_hf_image_processor_config,
-                                            get_pooling_config,
-                                            get_hf_text_config)
+                                            get_hf_text_config,
+                                            get_pooling_config)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
                         is_hip, is_openvino, is_xpu, print_warning_once)

vllm/model_executor/layers/pooler.py (2 changes: 1 addition & 1 deletion)
@@ -1,5 +1,5 @@
-from enum import IntEnum
 from dataclasses import dataclass
+from enum import IntEnum

 import torch
 import torch.nn as nn
vllm/model_executor/model_loader/loader.py (4 changes: 2 additions & 2 deletions)
@@ -18,19 +18,19 @@
 import torch
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
+from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

-from transformers import AutoModelForCausalLM, PretrainedConfig
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
                          LoRAConfig, ModelConfig, MultiModalConfig,
                          ParallelConfig, SchedulerConfig)
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
+from vllm.model_executor.layers.pooler import PoolingConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.pooler import PoolingConfig
 from vllm.model_executor.model_loader.tensorizer import (
     TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
     serialize_vllm_model, tensorizer_weights_iterator)
