diff --git a/example/model.py b/example/model.py index b7b7443..b6a3764 100644 --- a/example/model.py +++ b/example/model.py @@ -22,7 +22,7 @@ class GPTConfig: n_embd: int = 768 dropout: float = 0.0 bias: bool = False - attention = "flash_attention" # "standard_attention" + attention = "standard_attention" # "standard_attention", "flash_attention" # Masked Multi-Head Self-Attention