Skip to content

Commit

Permalink
Merge branch 'PygmalionAI:main' into amd-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Naomiusearch authored Dec 3, 2024
2 parents 531f7bc + 8b8d2ce commit 362cf85
Show file tree
Hide file tree
Showing 62 changed files with 6,433 additions and 385 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,6 @@ images/
*.exp
*.lib
*.obj

# generated files
**/generated/**
41 changes: 41 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA")
"kernels/quantization/gptq_marlin/awq_marlin_repack.cu"
"kernels/quantization/fp8/fp8_marlin.cu"
"kernels/all_reduce/custom_all_reduce.cu"
"kernels/permute_cols.cu"
"kernels/sampling/sampling.cu")

if(MSVC)
Expand Down Expand Up @@ -252,6 +253,46 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA")
endif()
endif()

#
# For the Machete kernels we automatically generate sources for various
# preselected input type pairs and schedules.
# Generate sources:
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/kernels/cutlass_extensions/:${CUTLASS_DIR}/python/:$PYTHONPATH
${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/kernels/quantization/machete/generate.py
RESULT_VARIABLE machete_generation_result
OUTPUT_VARIABLE machete_generation_output
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
)

if (NOT machete_generation_result EQUAL 0)
message(FATAL_ERROR "Machete generation failed."
" Result: \"${machete_generation_result}\""
"\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log")
else()
message(STATUS "Machete generation completed successfully.")
endif()

# Add machete generated sources
file(GLOB MACHETE_GEN_SOURCES "kernels/quantization/machete/generated/*.cu")
list(APPEND APHRODITE_EXT_SRC ${MACHETE_GEN_SOURCES})
message(STATUS "Machete generated sources: ${MACHETE_GEN_SOURCES}")

# See comment above for scaled_mm_c3x (same if condition)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
set_source_files_properties(
${MACHETE_GEN_SOURCES}
PROPERTIES
COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a")
endif()

# Add pytorch binding
list(APPEND APHRODITE_EXT_SRC
kernels/quantization/machete/machete_pytorch.cu)
endif()

define_gpu_extension_target(
Expand Down
29 changes: 29 additions & 0 deletions aphrodite/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,35 @@ def gptq_marlin_gemm(a: torch.Tensor,
is_zp_float)


# machete
def machete_supported_schedules(b_type: ScalarType) -> List[str]:
    """Return the schedules reported by the compiled machete extension.

    Forwards ``b_type`` unchanged to
    ``torch.ops._C.machete_supported_schedules`` and returns its result.
    """
    return torch.ops._C.machete_supported_schedules(b_type)


def machete_gemm(
    a: torch.Tensor,
    b_q: torch.Tensor,  # Should be the tensor returned by machete_prepack_B
    b_type: ScalarType,
    b_scales: Optional[torch.Tensor] = None,
    b_zeros: Optional[torch.Tensor] = None,
    b_group_size: Optional[int] = None,
    c: Optional[torch.Tensor] = None,
    alpha: Optional[float] = None,
    beta: Optional[float] = None,
    schedule: Optional[str] = None,
) -> torch.Tensor:
    """Call the ``machete_gemm`` custom op from the compiled extension.

    Every argument is forwarded positionally, in signature order, to
    ``torch.ops._C.machete_gemm``; the op's result is returned as-is.

    Args:
        a: First tensor operand.
        b_q: Second tensor operand; should be the tensor returned by
            ``machete_prepack_B``.
        b_type: Scalar type descriptor passed through to the op.
        b_scales: Optional tensor; presumably per-group scales for
            ``b_q`` -- TODO confirm against the kernel.
        b_zeros: Optional tensor; presumably zero points for ``b_q`` --
            TODO confirm against the kernel.
        b_group_size: Optional group size passed through to the op.
        c: Optional tensor passed through to the op.
        alpha: Optional scalar passed through to the op.
        beta: Optional scalar passed through to the op.
        schedule: Optional schedule name; presumably one of the values
            returned by ``machete_supported_schedules`` -- TODO confirm.

    Returns:
        The tensor produced by ``torch.ops._C.machete_gemm``.
    """
    # The op is invoked positionally, so the argument order below must stay
    # in sync with the op's registered signature.
    return torch.ops._C.machete_gemm(a, b_q, b_type, b_scales, b_zeros,
                                     b_group_size, c, alpha, beta, schedule)


def machete_prepack_B(b_q_weight: torch.Tensor,
                      b_type: ScalarType) -> torch.Tensor:
    """Call the ``machete_prepack_B`` custom op from the compiled extension.

    Forwards ``b_q_weight`` and ``b_type`` unchanged to
    ``torch.ops._C.machete_prepack_B`` and returns the resulting tensor.
    """
    return torch.ops._C.machete_prepack_B(b_q_weight, b_type)

def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    """Call the ``permute_cols`` custom op from the compiled extension.

    Forwards ``a`` and ``perm`` unchanged to ``torch.ops._C.permute_cols``
    and returns the resulting tensor.
    """
    return torch.ops._C.permute_cols(a, perm)


# fp8
def scaled_fp8_quant(
input: torch.Tensor,
Expand Down
11 changes: 9 additions & 2 deletions aphrodite/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,15 @@ def _verify_quantization(self) -> None:
quantization_override = method.override_quantization_method(
quant_cfg, self.quantization)
if quantization_override:
quant_method = quantization_override
self.quantization = quantization_override
if quantization_override == "awq_marlin":
quant_method = quant_method
logger.warning(
"awq_marlin kernels are temporarily disabled, "
"they will be re-enabled with a future release. "
"Falling back to AWQ kernels.")
else:
quant_method = quantization_override
self.quantization = quantization_override
break

# Verify quantization configurations.
Expand Down
85 changes: 84 additions & 1 deletion aphrodite/common/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,49 @@ class SamplingType(IntEnum):
RANDOM_SEED = 2
BEAM = 3

class SamplerID(IntEnum):
    """Integer identifiers for the individual samplers.

    Members can be looked up by integer value or by case-insensitive name
    through :meth:`from_str`.
    """
    # Mirror these in aphrodite/modeling/layers/sampler.py
    # Values out of order to keep backwards compatibility
    # with Koboldcpp values
    DRY = 7
    PENALTIES = 6
    NO_REPEAT_NGRAM = 8
    TEMPERATURE = 5
    TOP_NSIGMA = 9
    TOP_P_TOP_K = 0
    TOP_A = 1
    MIN_P = 2
    TFS = 3
    ETA_CUTOFF = 10
    EPSILON_CUTOFF = 11
    TYPICAL_P = 4
    QUADRATIC = 12
    XTC = 13

    @classmethod
    def from_str(cls, value: Union[str, int]) -> "SamplerID":
        """Convert string or int to SamplerID enum.

        Args:
            value: String name (case-insensitive) or integer value
        Returns:
            SamplerID enum value
        Raises:
            ValueError: If value cannot be converted to SamplerID
        """
        if isinstance(value, int):
            # Existing members pass through unchanged; an unknown integer
            # makes the enum lookup itself raise ValueError.
            return cls(value)

        if not isinstance(value, str):
            # Without this guard ``value.upper()`` would raise
            # AttributeError, which breaks the documented ValueError
            # contract that callers rely on when validating user input.
            raise ValueError(
                f"Invalid sampler identifier {value!r}: expected a sampler "
                f"name or integer value, got {type(value).__name__}")

        try:
            return cls[value.upper()]
        except KeyError as e:
            valid_names = [x.name for x in cls]
            raise ValueError(
                f"Invalid sampler name '{value}'. Must be one of: {valid_names}"
            ) from e


LogitsProcessorFunc = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
Callable[[List[int], List[int], torch.Tensor],
Expand Down Expand Up @@ -173,8 +216,12 @@ class SamplingParams(
input into sections where repetition is evaluated separately.
Common examples are newlines, quotes, and other structural tokens.
Defaults to None.
dry_range: The range of tokens (input + output) to apply the DRY
sampler.
skew: Bias the token selection towards higher or lower probability
tokens. Defaults to 0 (disabled).
sampler_priority: A list of integers to control the order in which
samplers are applied.
"""

n: int = 1
Expand Down Expand Up @@ -226,7 +273,9 @@ class SamplingParams(
dry_base: float = 1.75
dry_allowed_length: int = 2
dry_sequence_breaker_ids: List[int] = []
dry_range: int = 0
skew: float = 0.0
sampler_priority: Optional[List[int]] = []
# The below fields are not supposed to be used as an input.
# They are set in post_init.
output_text_buffer_length: int = 0
Expand Down Expand Up @@ -266,7 +315,7 @@ class SamplingParams(
"logprobs": None,
"prompt_logprobs": None,
"detokenize": True,
"custom_token_bans": [],
"custom_token_bans": None,
"skip_special_tokens": True,
"spaces_between_special_tokens": True,
"include_stop_str_in_output": False,
Expand All @@ -278,7 +327,9 @@ class SamplingParams(
"dry_base": 1.75,
"dry_allowed_length": 2,
"dry_sequence_breaker_ids": [],
"dry_range": 0,
"skew": 0.0,
"sampler_priority": [],
}

def __post_init__(self) -> None:
Expand Down Expand Up @@ -424,10 +475,42 @@ def _verify_args(self) -> None:
raise ValueError(
"dry_allowed_length must be non-negative, got "
f"{self.dry_allowed_length}.")
if self.dry_range < 0:
raise ValueError(
"dry_range must be non-negative, got "
f"{self.dry_range}.")
if self.skew < 0.0:
raise ValueError(
"skew must be non-negative, got "
f"{self.skew}.")

if self.sampler_priority is not None:
if not self.sampler_priority:
self.sampler_priority = None
return

if not isinstance(self.sampler_priority, list):
raise ValueError(
"sampler_priority must be a list of integers or strings")

try:
self.sampler_priority = [
SamplerID.from_str(x) for x in self.sampler_priority
]
provided_samplers = set(self.sampler_priority)
except ValueError as e:
raise ValueError(
f"Invalid sampler ID in priority list: {e}"
) from e

required_samplers = set(SamplerID)
if not required_samplers.issubset(provided_samplers):
missing = required_samplers - provided_samplers
missing_names = [s.name for s in missing]
raise ValueError(
"Missing required samplers in priority list: "
f"{missing_names}"
)

def _verify_beam_search(self) -> None:
if self.best_of == 1:
Expand Down
23 changes: 22 additions & 1 deletion aphrodite/endpoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import (AliasChoices, BaseModel, ConfigDict, Field,
model_validator)
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated

Expand Down Expand Up @@ -154,12 +155,20 @@ class ChatCompletionRequest(OpenAIBaseModel):
dry_allowed_length: Optional[int] = 2
dry_sequence_breakers: Optional[List[str]] = Field(
default=["\n", ":", "\"", "*"])
dry_range: Optional[int] = Field(
default=0,
validation_alias=AliasChoices("dry_range",
"dry_penalty_last_n"))
dynatemp_min: Optional[float] = 0.0
dynatemp_max: Optional[float] = 0.0
dynatemp_exponent: Optional[float] = 1.0
nsigma: Optional[float] = 0.0
skew: Optional[float] = 0.0
custom_token_bans: Optional[List[int]] = None
sampler_priority: Optional[Union[List[int], List[str]]] = Field(
default=[],
validation_alias=AliasChoices("sampler_priority",
"sampler_order"))
# doc: end-chat-completion-sampling-params

# doc: begin-chat-completion-extra-params
Expand Down Expand Up @@ -311,12 +320,14 @@ def to_sampling_params(
dry_base=self.dry_base,
dry_allowed_length=self.dry_allowed_length,
dry_sequence_breaker_ids=dry_sequence_breaker_ids,
dry_range=self.dry_range,
dynatemp_min=self.dynatemp_min,
dynatemp_max=self.dynatemp_max,
dynatemp_exponent=self.dynatemp_exponent,
nsigma=self.nsigma,
skew=self.skew,
custom_token_bans=self.custom_token_bans,
sampler_priority=self.sampler_priority,
)

@model_validator(mode='before')
Expand Down Expand Up @@ -430,12 +441,20 @@ class CompletionRequest(OpenAIBaseModel):
dry_allowed_length: Optional[int] = 2
dry_sequence_breakers: Optional[List[str]] = Field(
default=["\n", ":", "\"", "*"])
dry_range: Optional[int] = Field(
default=0,
validation_alias=AliasChoices("dry_range",
"dry_penalty_last_n"))
dynatemp_min: Optional[float] = 0.0
dynatemp_max: Optional[float] = 0.0
dynatemp_exponent: Optional[float] = 1.0
nsigma: Optional[float] = 0.0
skew: Optional[float] = 0.0
custom_token_bans: Optional[List[int]] = None
sampler_priority: Optional[Union[List[int], List[str]]] = Field(
default=[],
validation_alias=AliasChoices("sampler_priority",
"sampler_order"))
# doc: end-completion-sampling-params

# doc: begin-completion-extra-params
Expand Down Expand Up @@ -546,12 +565,14 @@ def to_sampling_params(
dry_base=self.dry_base,
dry_allowed_length=self.dry_allowed_length,
dry_sequence_breaker_ids=dry_sequence_breaker_ids,
dry_range=self.dry_range,
dynatemp_min=self.dynatemp_min,
dynatemp_max=self.dynatemp_max,
dynatemp_exponent=self.dynatemp_exponent,
nsigma=self.nsigma,
skew=self.skew,
custom_token_bans=self.custom_token_bans,
sampler_priority=self.sampler_priority,
)

@model_validator(mode="before")
Expand Down
29 changes: 17 additions & 12 deletions aphrodite/executor/gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,25 +61,30 @@ def _get_worker_kwargs(
or (rank % self.parallel_config.tensor_parallel_size == 0),
)

def _get_worker_module_and_class(self) -> Tuple[str, str]:
    """Select which worker implementation this executor should load.

    Returns:
        A ``(module_name, class_name)`` pair. Multi-step scheduling is
        checked first, then speculative decoding; if neither is enabled
        the default worker is returned.
    """
    if self.scheduler_config.is_multi_step:
        return ("aphrodite.task_handler.multi_step_worker",
                "MultiStepWorker")
    if self.speculative_config:
        return ("aphrodite.spec_decode.spec_decode_worker",
                "create_spec_worker")
    return ("aphrodite.task_handler.worker", "Worker")

def _get_create_worker_kwargs(
        self,
        local_rank: int = 0,
        rank: int = 0,
        distributed_init_method: Optional[str] = None) -> Dict:
    """Assemble the keyword arguments used to create a worker.

    Starts from the dict returned by ``self._get_worker_kwargs`` and adds
    the ``worker_module_name`` / ``worker_class_name`` entries chosen by
    ``self._get_worker_module_and_class``.

    Args:
        local_rank: Forwarded to ``self._get_worker_kwargs``.
        rank: Forwarded to ``self._get_worker_kwargs``.
        distributed_init_method: Forwarded to ``self._get_worker_kwargs``.

    Returns:
        The worker kwargs dict, including the module and class name entries.
    """
    worker_kwargs = self._get_worker_kwargs(local_rank, rank,
                                            distributed_init_method)
    # The worker selection lives in one helper so that it is not duplicated
    # here; overriding the helper is enough to change which worker is used.
    (worker_module_name,
     worker_class_name) = self._get_worker_module_and_class()
    worker_kwargs.update(worker_module_name=worker_module_name,
                         worker_class_name=worker_class_name)

    return worker_kwargs

def _create_worker(self,
Expand Down
Loading

0 comments on commit 362cf85

Please sign in to comment.