diff --git a/gpt_conf.py b/gpt_conf.py
index 7fbbdd48f..daad3d9aa 100644
--- a/gpt_conf.py
+++ b/gpt_conf.py
@@ -70,6 +70,9 @@ class GPTConfig:
     ## Flash attention
     disable_flash_attention: bool = False
 
+    # Attention Options
+    attention_variant: str = "causal"
+
     # MLP Options
     use_parallel_mlp: bool = False
     mlp_variant: str = "mlp"
diff --git a/model.py b/model.py
index ddbc479bb..691df052b 100644
--- a/model.py
+++ b/model.py
@@ -83,12 +83,6 @@ def create_shared_param_group(layer_type, config):
     shared_group = []
     layer_block = None
 
-    # For demonstration: we might have a config.attention_variants list
-    # that has multiple keys, e.g. ["causal","fancy"] or just ["causal"] repeated.
-    # We'll pick the attention class in a round-robin fashion as an example.
-    # (Adjust to your own logic if you want different mixing.)
-    attn_variants = getattr(config, "attention_variants", ["causal"]) # fallback if not present
-
     for i in range(config.n_layer):
 
         # Create a new layer block every "shared_size"
@@ -101,11 +95,7 @@ def create_shared_param_group(layer_type, config):
             layer_block = get_mlp_instance(config)
 
         elif layer_type == "attn":
-            # Example: select the i-th attention variant in a round-robin style
-            attn_type_index = i % len(attn_variants)
-            attn_type_name = attn_variants[attn_type_index]
-            # Look up the attention class from a dictionary
-            attn_cls = attention_dictionary[attn_type_name]
+            attn_cls = attention_dictionary[config.attention_variant]
             # Instantiate an attention layer
             layer_block = attn_cls(config, fire_pos_enc=fire_pos_enc)
 
@@ -150,7 +140,7 @@ def __init__(self, config, mlp=None, attn=None):
 
         # Allow for sharing attn between blocks
         if attn is None:
-            self.attn = CausalSelfAttention(config)
+            self.attn = attention_dictionary[config.attention_variant](config)
         else:
             self.attn = attn
 
diff --git a/train_args.py b/train_args.py
index d41269011..4e6d71cff 100644
--- a/train_args.py
+++ b/train_args.py
@@ -269,6 +269,17 @@ def parse_args():
     ## LearnedSplineActivation - lsa
     model_group.add_argument("--lsa_num_knots", type=int, default=30)
+
+    # Attention Variations
+    model_group.add_argument(
+        "--attention_variant",
+        type=str,
+        default="causal",
+        choices=["causal"],
+        help="Which attention variant to use for the Transformer blocks."
+    )
+
+
 
     # LINEAR VARIATIONS
     linear_variants = ["linear", "bitlinear", "bitlinear_1p58", "bitlinear_optimized", "kan","quantized_linear"]
     model_group.add_argument("--linear_variant_attn", type=str, default="linear", choices=linear_variants)
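
Note: the lookup above relies on an attention_dictionary mapping in model.py whose keys match the --attention_variant choices. A minimal sketch of how a new variant might be registered is below; the "fancy" class name and the train.py entry point are hypothetical and not part of this patch.

    # model.py (sketch, assuming this is where attention_dictionary lives):
    # map config/CLI names to attention classes.
    attention_dictionary = {
        "causal": CausalSelfAttention,   # existing default, selected by --attention_variant causal
        # "fancy": FancySelfAttention,   # hypothetical: add the new class here and
        #                                # extend choices=[...] in train_args.py
    }

    # hypothetical invocation, assuming a train.py that consumes parse_args():
    # python train.py --attention_variant causal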