Skip to content

Commit

Permalink
PR comments
Browse files Browse the repository at this point in the history
Signed-off-by: luka <[email protected]>
  • Loading branch information
ProExpertProg committed Dec 4, 2024
1 parent 7901c70 commit 7046e4b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions tests/kernels/test_fused_quant_layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
DTYPES = [torch.bfloat16, torch.float]
QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
NUM_TOKENS = [1, 7, 83, 2048, 4096] # Arbitrary values for testing
HIDDEN_SIZES = [1, 2, 3, 4, 16, 64, 67, 768, 2048, 5120, 5137, 8192,
8193] # Arbitrary values for testing
HIDDEN_SIZES = [1, 3, 4, 16, 64, 2048, 5120,
5137] # Arbitrary values for testing
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
ADD_RESIDUAL = [False, True]
SCALE_UBS = [True, False]
Expand Down
10 changes: 5 additions & 5 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
supports_moe_ops = False
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401

supports_moe_ops = True

# neuron has torch version that doesn't even have impl_abstract
Expand Down Expand Up @@ -242,6 +241,7 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
paged_kv_indptr: torch.Tensor,
paged_kv_last_page_len: torch.Tensor,
block_table_bound: torch.Tensor) -> None:

return torch.ops._C.advance_step_flashinfer(
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
input_positions, seq_lens, slot_mapping, block_tables,
Expand Down Expand Up @@ -737,7 +737,7 @@ def scaled_fp8_quant(
shape: Union[Tuple[int, int], torch.Size] = input.shape
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype: torch.dtype = torch.float8_e4m3fnuz \
if current_platform.is_rocm() else torch.float8_e4m3fn
if current_platform.is_rocm() else torch.float8_e4m3fn
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype)
Expand Down Expand Up @@ -1020,9 +1020,9 @@ def register_graph_buffers(fa: int, handles: List[List[int]],
# the case when users use `import __annotations__` to turn type
# hints into strings.
if isinstance(v, fn_type) \
and v.__code__.co_filename == __file__ \
and any(arg is torch.Tensor or arg == "torch.Tensor"
for arg in v.__annotations__.values()):
and v.__code__.co_filename == __file__ \
and any(arg is torch.Tensor or arg == "torch.Tensor"
for arg in v.__annotations__.values()):
names_and_values_to_update[k] = hint_on_error(v)

names_and_values.update(names_and_values_to_update)
Expand Down

0 comments on commit 7046e4b

Please sign in to comment.