Add arguments for selecting between attn types
klei22 committed Feb 3, 2025
1 parent 35c5959 commit 1fd1d62
Showing 3 changed files with 16 additions and 12 deletions.
gpt_conf.py (3 changes: 3 additions & 0 deletions)
@@ -70,6 +70,9 @@ class GPTConfig:
     ## Flash attention
     disable_flash_attention: bool = False
 
+    # Attention Options
+    attention_variant: str = "causal"
+
     # MLP Options
     use_parallel_mlp: bool = False
     mlp_variant: str = "mlp"
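The new attention_variant field defaults to "causal", so configurations that never set it behave exactly as before this commit. Below is a minimal sketch of setting the field, assuming GPTConfig is a dataclass-style config that accepts keyword arguments (the snippet is illustrative and not part of the commit):

from gpt_conf import GPTConfig

# Default: same behavior as before the field was added.
config = GPTConfig()
assert config.attention_variant == "causal"

# Explicit selection; "causal" is the only choice exposed so far.
config = GPTConfig(attention_variant="causal")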
model.py (14 changes: 2 additions & 12 deletions)
@@ -83,12 +83,6 @@ def create_shared_param_group(layer_type, config):
     shared_group = []
     layer_block = None
 
-    # For demonstration: we might have a config.attention_variants list
-    # that has multiple keys, e.g. ["causal","fancy"] or just ["causal"] repeated.
-    # We'll pick the attention class in a round-robin fashion as an example.
-    # (Adjust to your own logic if you want different mixing.)
-    attn_variants = getattr(config, "attention_variants", ["causal"]) # fallback if not present
-
     for i in range(config.n_layer):
 
         # Create a new layer block every "shared_size"
@@ -101,11 +95,7 @@ def create_shared_param_group(layer_type, config):
                 layer_block = get_mlp_instance(config)
 
             elif layer_type == "attn":
-                # Example: select the i-th attention variant in a round-robin style
-                attn_type_index = i % len(attn_variants)
-                attn_type_name = attn_variants[attn_type_index]
-                # Look up the attention class from a dictionary
-                attn_cls = attention_dictionary[attn_type_name]
+                attn_cls = attention_dictionary[config.attention_variant]
 
                 # Instantiate an attention layer
                 layer_block = attn_cls(config, fire_pos_enc=fire_pos_enc)
@@ -150,7 +140,7 @@ def __init__(self, config, mlp=None, attn=None):
 
         # Allow for sharing attn between blocks
         if attn is None:
-            self.attn = CausalSelfAttention(config)
+            self.attn = attention_dictionary[config.attention_variant](config)
         else:
             self.attn = attn
 
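Both call sites in model.py now resolve the attention class through attention_dictionary, keyed by config.attention_variant, instead of hard-coding CausalSelfAttention or cycling through a list. The dictionary itself is not shown in this diff; the sketch below illustrates the shape the lookup assumes, reusing names from the surrounding file (the helper function and any entry other than "causal" are hypothetical):

# Assumed shape of the registry that model.py's lookups rely on.
attention_dictionary = {
    "causal": CausalSelfAttention,  # the class used unconditionally before this commit
}

def build_attention(config, fire_pos_enc=None):
    # Hypothetical helper mirroring the lookups in create_shared_param_group
    # and Block.__init__: resolve the class by name, then instantiate it.
    attn_cls = attention_dictionary[config.attention_variant]
    return attn_cls(config, fire_pos_enc=fire_pos_enc)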
train_args.py (11 changes: 11 additions & 0 deletions)
@@ -269,6 +269,17 @@ def parse_args():
     ## LearnedSplineActivation - lsa
     model_group.add_argument("--lsa_num_knots", type=int, default=30)
 
+
+    # Attention Variations
+    model_group.add_argument(
+        "--attention_variant",
+        type=str,
+        default="causal",
+        choices=["causal"],
+        help="Which attention variant to use for the Transformer blocks."
+    )
+
+
     # LINEAR VARIATIONS
     linear_variants = ["linear", "bitlinear", "bitlinear_1p58", "bitlinear_optimized", "kan","quantized_linear"]
     model_group.add_argument("--linear_variant_attn", type=str, default="linear", choices=linear_variants)
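The new flag only registers the option; the selected name still has to reach GPTConfig.attention_variant when the model is built. That wiring is outside this diff; below is a minimal sketch of what it might look like on the training side, assuming parse_args() returns the parsed argparse namespace and that GPTConfig accepts keyword arguments (the train.py entry point is illustrative):

from train_args import parse_args
from gpt_conf import GPTConfig

# e.g. launched as: python train.py --attention_variant causal
args = parse_args()
config = GPTConfig(attention_variant=args.attention_variant)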
