Merge pull request #384 from klei22/add_learned_off_by_one_per_head
Add learned off-by-one per head
gkielian authored Feb 4, 2025
2 parents 7bdb04d + bdae5b4 commit 9f2f909
Showing 4 changed files with 28 additions and 6 deletions.
explorations/lobo_sweep.json: 16 additions, 0 deletions
@@ -0,0 +1,16 @@
+[
+  {
+    "max_iters": ["10000"],
+    "device": ["cuda"],
+    "dtype": ["bfloat16"],
+    "dataset": ["cosmopedia_100k"],
+    "compile": [true],
+    "softmax_variant_attn": ["strongermax"],
+    "tensorboard_run_name": ["obo_variations_learned"],
+    "strongermax_obo": ["0.0", "0.1", "1.0"],
+    "strongermax_use_learned_obo": [true],
+    "strongermax_use_learned_obo_per_head": [false],
+    "dropout": ["0.0","0.1"]
+  }
+]
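
Each key in the sweep file holds the list of values to try; an exploration run expands the Cartesian product of those lists into individual training configurations (here 3 strongermax_obo values times 2 dropout values, i.e. 6 runs). A minimal sketch of that expansion, assuming each key maps directly onto a train_args.py flag (hypothetical helper, not the repo's actual exploration runner):

    import itertools
    import json

    # Hypothetical sketch: expand a sweep JSON into one flat config dict per run.
    def expand_sweep(path):
        with open(path) as f:
            sweeps = json.load(f)
        for sweep in sweeps:
            keys, value_lists = zip(*sweep.items())
            for combo in itertools.product(*value_lists):
                yield dict(zip(keys, combo))

    for run_cfg in expand_sweep("explorations/lobo_sweep.json"):
        print(run_cfg)  # e.g. {"max_iters": "10000", ..., "strongermax_obo": "0.1", "dropout": "0.0"}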

gpt_conf.py: 1 addition, 0 deletions
@@ -143,6 +143,7 @@ class GPTConfig:
 
     strongermax_obo: float = 0.0
     strongermax_use_learned_obo: bool = False
+    strongermax_use_learned_obo_per_head: bool = False
 
     strongermax_temperature_factor: float = 1.0
     strongermax_use_learned_temperature_factor: bool = False
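
For reference, a hedged sketch of setting the new field when constructing a config. It assumes all GPTConfig fields have defaults and that softmax_variant_attn is the field selecting the Strongermax variant, mirroring the key used in the sweep JSON above; only this PR's fields are shown.

    from gpt_conf import GPTConfig

    # Sketch only; every field not listed keeps its default.
    config = GPTConfig(
        softmax_variant_attn="strongermax",          # assumed field name, mirrors the sweep JSON key
        strongermax_obo=0.1,                         # initial off-by-one offset
        strongermax_use_learned_obo=True,            # learn a single shared offset
        strongermax_use_learned_obo_per_head=False,  # new flag: learn one offset per attention head instead
    )
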
train_args.py: 1 addition, 0 deletions
@@ -490,6 +490,7 @@ def parse_args():
     ### From https://www.evanmiller.org/attention-is-off-by-one.html
     model_group.add_argument('--strongermax_obo', type=float, default=0.0)
     model_group.add_argument('--strongermax_use_learned_obo', default=False, action=argparse.BooleanOptionalAction)
+    model_group.add_argument('--strongermax_use_learned_obo_per_head', default=False, action=argparse.BooleanOptionalAction)
 
     ### Temperature adjustment factor
     model_group.add_argument('--strongermax_temperature_factor', type=float, default=1.0)
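
Because the new flag uses argparse.BooleanOptionalAction (Python 3.9+), it automatically gains a --no- counterpart. A standalone sketch of just this flag's behavior (the real parser in train_args.py defines many more arguments):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--strongermax_use_learned_obo_per_head', default=False,
                        action=argparse.BooleanOptionalAction)

    print(parser.parse_args([]))
    # Namespace(strongermax_use_learned_obo_per_head=False)
    print(parser.parse_args(['--strongermax_use_learned_obo_per_head']))
    # Namespace(strongermax_use_learned_obo_per_head=True)
    print(parser.parse_args(['--no-strongermax_use_learned_obo_per_head']))
    # Namespace(strongermax_use_learned_obo_per_head=False)
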
variations/softmax_variations.py: 10 additions, 6 deletions
@@ -187,6 +187,7 @@ class Strongermax(nn.Module):
     def __init__(self, config, dim=-1):
         super().__init__()
         self.dim = dim
+        self.n_head = config.n_head
 
         # Strongermax Params
         self.strength = config.strongermax_strength
@@ -208,7 +209,6 @@ def __init__(self, config, dim=-1):
 
         # Set optional temperature (already divided by sqrt head dimension)
         self.use_learned_temperature_factor = config.strongermax_use_learned_temperature_factor
-
         if self.use_learned_temperature_factor:
             self.temperature_factor = nn.Parameter(torch.Tensor([config.strongermax_temperature_factor]))
         else:
@@ -218,7 +218,7 @@ def __init__(self, config, dim=-1):
         self.iter_num = 0
 
         if self.overflow_recompute:
-            assert self.xmax_guess is not None, "for overflow recompute, xmax_guess must be set"
+            assert self.xmax_guess is not None, "For overflow recompute, xmax_guess must be set"
 
         # Input and Output Logging
         self.softmax_io_logging = config.softmax_io_logging
@@ -228,10 +228,15 @@ def __init__(self, config, dim=-1):
 
         # self.obo_offset default is 0.0, https://www.evanmiller.org/attention-is-off-by-one.html
         self.use_learned_obo = config.strongermax_use_learned_obo
-        if self.use_learned_obo:
-            self.obo_offset = nn.Parameter(torch.Tensor([config.strongermax_obo]))
+        self.use_learned_obo_per_head = config.strongermax_use_learned_obo_per_head
+
+        if self.use_learned_obo_per_head:
+            self.obo_offset = nn.Parameter(torch.ones(self.n_head, 1, 1) * config.strongermax_obo)
         else:
-            self.obo_offset = config.strongermax_obo
+            if self.use_learned_obo:
+                self.obo_offset = nn.Parameter(torch.Tensor([config.strongermax_obo]))
+            else:
+                self.obo_offset = config.strongermax_obo
 
     def forward(self, x):
         x_adj = x
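
With strongermax_use_learned_obo_per_head enabled, obo_offset has shape (n_head, 1, 1), so each head learns its own additive term in the softmax denominator; the shared-scalar option keeps a single learned value for all heads. A hedged sketch of the broadcasting, using a plain exponential softmax for brevity rather than Strongermax's configurable base and temperature:

    import torch

    B, n_head, T = 2, 4, 8
    scores = torch.randn(B, n_head, T, T)                             # attention logits
    obo_offset = torch.nn.Parameter(torch.ones(n_head, 1, 1) * 0.1)   # one learnable offset per head

    e = torch.exp(scores - scores.max(dim=-1, keepdim=True).values)
    # (n_head, 1, 1) broadcasts against the (B, n_head, T, 1) row sums,
    # so head h adds obo_offset[h] to every one of its denominators.
    attn = e / (obo_offset + e.sum(dim=-1, keepdim=True))
    print(attn.shape)  # torch.Size([2, 4, 8, 8])
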
@@ -243,7 +248,6 @@ def forward(self, x):
         # Guessing correctly instead of subtracting real max can save a pass
         # else we use real xmax
         max_x = x_adj.max(dim=self.dim, keepdim=True).values
-
         if self.overflow_recompute:
             if (torch.max(x_adj - self.xmax_guess)) > self.overflow_recompute_value:
                 x_adj = x_adj - max_x
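
The context lines above also show the overflow_recompute path: the scores are shifted by a fixed xmax_guess when that is safe, and re-shifted by the true per-row max only when the guess would overflow. A rough standalone sketch of that idea (my reading of the snippet, not the repo's exact code path):

    import torch

    def stable_shift(x, xmax_guess=10.0, overflow_recompute_value=80.0, dim=-1):
        # Rough sketch: prefer shifting by a fixed guess of the max, and fall
        # back to the true max only when exceeding the guess risks overflow.
        max_x = x.max(dim=dim, keepdim=True).values
        if torch.max(x - xmax_guess) > overflow_recompute_value:
            return x - max_x       # recompute path: true max
        return x - xmax_guess      # fast path: fixed guess

    print(stable_shift(torch.randn(2, 8)).shape)  # torch.Size([2, 8])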
