Commit: ruff

tgaddair committed Oct 18, 2024
1 parent 401f1ae commit 1d9f5d4
Showing 27 changed files with 119 additions and 217 deletions.

@@ -20,8 +20,6 @@
 
 from typing import List, Optional, Tuple
 
-from lorax_server.utils.attention.common import Seqlen
-
 import dropout_layer_norm
 import rotary_emb
 import torch
@@ -32,6 +30,7 @@
 from lorax_server.adapters.weights import AdapterBatchData
 from lorax_server.models.custom_modeling.utils import prepend
 from lorax_server.utils import flash_attn, paged_attention
+from lorax_server.utils.attention.common import Seqlen
 from lorax_server.utils.layers import (
     FastLayerNorm,
     MultiAdapterHead,

@@ -15,7 +15,6 @@
 
 from typing import Any, List, Optional, Tuple
 
-from lorax_server.utils.attention.common import Seqlen
 import numpy as np
 import torch
 import torch.distributed
@@ -27,6 +26,7 @@
 from lorax_server.adapters.weights import AdapterBatchData
 from lorax_server.models.custom_modeling.utils import prepend
 from lorax_server.utils import flash_attn, paged_attention
+from lorax_server.utils.attention.common import Seqlen
 from lorax_server.utils.layers import (
     FastLayerNorm,
     FastLinear,

@@ -20,7 +20,6 @@
 
 from typing import List, Optional, Tuple
 
-from lorax_server.utils.attention.common import Seqlen
 import torch
 import torch.distributed
 from torch import nn
@@ -40,6 +39,7 @@
 from lorax_server.layers.rotary import PositionRotaryEmbedding
 from lorax_server.layers.tensor_parallel import TensorParallelHead
 from lorax_server.utils import flash_attn, paged_attention
+from lorax_server.utils.attention.common import Seqlen
 from lorax_server.utils.layers import MultiAdapterHead, TensorParallelAdapterRowLinear, TensorParallelMultiAdapterLinear
 from lorax_server.utils.lora import (
     DOWN_PROJ,

@@ -15,7 +15,6 @@
 
 from typing import List, Optional, Tuple
 
-from lorax_server.utils.attention.common import Seqlen
 import torch
 import torch.distributed
 from torch import nn
@@ -26,6 +25,7 @@
 from lorax_server.adapters import AdapterBatchData
 from lorax_server.models.custom_modeling.utils import prepend
 from lorax_server.utils import flash_attn, paged_attention
+from lorax_server.utils.attention.common import Seqlen
 from lorax_server.utils.layers import (
     PositionRotaryEmbedding,
     TensorParallelAdapterRowLinear,

@@ -20,7 +20,6 @@
 
 from typing import List, Optional, Tuple
 
-from lorax_server.utils.attention.common import Seqlen
 import torch
 import torch.distributed
 from torch import nn
@@ -31,6 +30,7 @@
 from lorax_server.adapters import AdapterBatchData
 from lorax_server.models.custom_modeling.utils import prepend
 from lorax_server.utils import flash_attn, paged_attention
+from lorax_server.utils.attention.common import Seqlen
 from lorax_server.utils.layers import (
     FastLayerNorm,
     TensorParallelAdapterRowLinear,

@@ -20,8 +20,6 @@
 
 from typing import List, Optional, Tuple
 
-from lorax_server.utils.attention.common import Seqlen
-
 # Flash attention imports
 import dropout_layer_norm
 import torch
@@ -32,6 +30,7 @@
 
 from lorax_server.adapters import AdapterBatchData
 from lorax_server.utils import flash_attn, paged_attention
+from lorax_server.utils.attention.common import Seqlen
 from lorax_server.utils.layers import (
     MultiAdapterHead,
     PositionRotaryEmbedding,

@@ -20,8 +20,6 @@
 
 from typing import List, Optional, Tuple
 
-from lorax_server.utils.attention.common import Seqlen
-
 # Flash attention imports
 import dropout_layer_norm
 import torch
@@ -32,6 +30,7 @@
 
 from lorax_server.adapters import AdapterBatchData
 from lorax_server.utils import flash_attn, paged_attention
+from lorax_server.utils.attention.common import Seqlen
 from lorax_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
 from lorax_server.utils.layers import (
     MultiAdapterHead,

@@ -21,8 +21,6 @@
 
 from typing import List, Optional, Tuple
 
-from lorax_server.utils.attention.common import Seqlen
-
 # Flash attention imports
 import dropout_layer_norm
 import numpy as np
@@ -35,6 +33,7 @@
 from lorax_server.adapters import AdapterBatchData
 from lorax_server.models.custom_modeling.utils import prepend
 from lorax_server.utils import flash_attn, paged_attention
+from lorax_server.utils.attention.common import Seqlen
 from lorax_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
 from lorax_server.utils.layers import (
     FastLinear,
@@ -189,13 +188,15 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}"
 
-    return TensorParallelColumnLinear(get_linear(
-        weight,
-        bias=None,
-        quantize=config.quantize,
-        weight_scale=weight_scale,
-        input_scale=input_scale,
-    ))
+    return TensorParallelColumnLinear(
+        get_linear(
+            weight,
+            bias=None,
+            quantize=config.quantize,
+            weight_scale=weight_scale,
+            input_scale=input_scale,
+        )
+    )
 
 
 def _load_experts(config, prefix, mat, weights):

@@ -20,7 +20,6 @@
 
 from typing import List, Optional, Tuple
 
-from lorax_server.utils.attention.common import Seqlen
 import torch
 import torch.distributed
 from torch import nn
@@ -30,6 +29,7 @@
 
 from lorax_server.models.custom_modeling.utils import prepend
 from lorax_server.utils import flash_attn, paged_attention
+from lorax_server.utils.attention.common import Seqlen
 from lorax_server.utils.layers import (
     FastLayerNorm,
     PositionRotaryEmbedding,

@@ -20,8 +20,6 @@
 
 from typing import List, Optional, Tuple
 
-from lorax_server.utils.attention.common import Seqlen
-
 # Flash attention imports
 import dropout_layer_norm
 import torch
@@ -33,6 +31,7 @@
 from lorax_server.adapters import AdapterBatchData
 from lorax_server.models.custom_modeling.utils import prepend
 from lorax_server.utils import flash_attn, paged_attention
+from lorax_server.utils.attention.common import Seqlen
 from lorax_server.utils.layers import (
     MultiAdapterHead,
     PositionRotaryEmbedding,

@@ -11,7 +11,6 @@
 
 from typing import List, Optional, Tuple
 
-from lorax_server.utils.attention.common import Seqlen
 import torch
 import torch.distributed
 from torch import nn
@@ -20,6 +19,7 @@
 from lorax_server.adapters import AdapterBatchData
 from lorax_server.models.custom_modeling.utils import prepend
 from lorax_server.utils import flash_attn, paged_attention
+from lorax_server.utils.attention.common import Seqlen
 from lorax_server.utils.layers import (
     FastLayerNorm,
     MultiAdapterHead,

@@ -6,8 +6,6 @@
 
 from typing import List, Optional, Tuple
 
-from lorax_server.utils.attention.common import Seqlen
-
 # Flash attention imports
 import dropout_layer_norm
 import torch
@@ -18,6 +16,7 @@
 from lorax_server.adapters import AdapterBatchData
 from lorax_server.models.custom_modeling.utils import prepend
 from lorax_server.utils import flash_attn, paged_attention
+from lorax_server.utils.attention.common import Seqlen
 from lorax_server.utils.layers import (
     MultiAdapterHead,
     PositionRotaryEmbedding,

@@ -6,8 +6,6 @@
 
 from typing import List, Optional, Tuple
 
-from lorax_server.utils.attention.common import Seqlen
-
 # Flash attention imports
 import dropout_layer_norm
 import torch
@@ -19,6 +17,7 @@
 from lorax_server.adapters import AdapterBatchData
 from lorax_server.models.custom_modeling.utils import prepend
 from lorax_server.utils import flash_attn, paged_attention
+from lorax_server.utils.attention.common import Seqlen
 from lorax_server.utils.layers import (
     MultiAdapterHead,
     PositionRotaryEmbedding,

@@ -1,6 +1,5 @@
 from typing import List, Optional, Tuple
 
-from lorax_server.utils.attention.common import Seqlen
 import torch
 import torch.distributed
 from torch import nn
@@ -9,6 +8,7 @@
 
 from lorax_server.models.custom_modeling.utils import prepend
 from lorax_server.utils import flash_attn, paged_attention
+from lorax_server.utils.attention.common import Seqlen
 from lorax_server.utils.layers import (
     FastLayerNorm,
     PositionRotaryEmbedding,

@@ -1,13 +1,13 @@
 from typing import List, Optional, Tuple
 
-from lorax_server.utils.attention.common import Seqlen
 import torch
 import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 
 from lorax_server.models.custom_modeling.utils import prepend
 from lorax_server.utils import flash_attn, paged_attention
+from lorax_server.utils.attention.common import Seqlen
 from lorax_server.utils.layers import (
     FastLayerNorm,
     TensorParallelColumnLinear,

server/lorax_server/models/custom_modeling/llava_next.py (1 addition, 1 deletion)
@@ -16,7 +16,6 @@
 
 from typing import List, Optional, Tuple
 
-from lorax_server.utils.attention.common import Seqlen
 import torch
 import torch.utils.checkpoint
 from torch import nn
@@ -32,6 +31,7 @@
     load_text_model,
     load_vision_model,
 )
+from lorax_server.utils.attention.common import Seqlen
 
 
 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):

server/lorax_server/models/custom_modeling/mllama.py (1 addition, 2 deletions)
@@ -16,8 +16,6 @@
 
 from typing import List, Optional, Tuple
 
-from lorax_server.utils.attention.common import Seqlen
-
 import flash_attn_2_cuda
 import torch
 import torch.nn.functional as F
@@ -36,6 +34,7 @@
     FlashLlamaForCausalLM,
     FlashLlamaLayer,
 )
+from lorax_server.utils.attention.common import Seqlen
 
 
 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb

(Diffs for the remaining changed files are not shown.)