Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes for 65B and 70B runs #414

Merged
merged 130 commits into from
Mar 19, 2024
Merged

Changes for 65B and 70B runs #414

merged 130 commits into from
Mar 19, 2024

Conversation

dirkgr
Copy link
Member

@dirkgr dirkgr commented Jan 24, 2024

No description provided.

Copy link
Collaborator

@Muennighoff Muennighoff left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The GQA looks good to me

olmo/config.py Outdated
@@ -243,6 +242,14 @@ class ModelConfig(BaseConfig):
The number of self-attention heads.
"""

n_kv_heads: Optional[int] = None
"""
The number of heads to use for keys and values.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The number of heads to use for keys and values.
The number of heads to use for keys and values. Defaults to `n_heads`.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, I just can't click "commit" here for some reason.

olmo/config.py Outdated Show resolved Hide resolved
olmo/config.py Outdated Show resolved Hide resolved
olmo/config.py Outdated
Comment on lines 456 to 457
if hasattr(new_config, "optimizer"):
new_config.optimizer = OptimizerConfig.update_legacy_settings(new_config.optimizer)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need this here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am learning that update_leacy_settings doesn't work anyways with settings you specify on the command line.

Copy link
Member

@epwalsh epwalsh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's an alternative approach that doesn't involve implementing update_legacy_settings().

olmo/config.py Outdated
@@ -309,8 +317,7 @@ class ModelConfig(BaseConfig):

multi_query_attention: bool = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make this Optional[bool], defaulting to None.

olmo/config.py Outdated
Comment on lines 439 to 440
if self.n_kv_heads is None:
self.n_kv_heads = self.n_heads
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then here we could do this:

if self.multi_query_attention:
    self.n_kv_heads = 1
elif self.n_kv_heads is None:
    self.n_kv_heads = self.n_heads

olmo/config.py Outdated
self.n_kv_heads = self.n_heads

@classmethod
def update_legacy_settings(cls, config: D) -> D:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then this won't be needed.

@epwalsh epwalsh changed the title Some more changes for the 65B run Changes for 65B and 70B runs Mar 18, 2024
Copy link
Member Author

@dirkgr dirkgr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is some leftover debug commenting? Other than that, looks good.

# Save metadata.
self._save_metadata(checkpoint_dir, upload_to=upload_to)

# Save config.
self._save_config(checkpoint_dir, upload_to=upload_to)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment explaining why this is now last?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

olmo/model.py Outdated
@@ -245,7 +260,7 @@ def __init__(self, config: ModelConfig, cache: BufferCache):
self.config = config
self.__cache = cache
# Warm up cache.
self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))
# self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need this anymore?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted.

scripts/beaker/mitchish70.sh Outdated Show resolved Hide resolved
SEED=3423
INIT=fan_in
RUN_NAME="fan-in-init-${SEED}"
ARGS="--run_name=${RUN_NAME} --data.seed=6198 --seed=${SEED} --model.init_fn=${INIT} --model.init_std=0.006 --model.init_cutoff_factor=3 --device_train_microbatch_size=4 --model.flash_attention=true --fused_loss=true --evaluators=[] --stop_at=500 --wandb.group=mitchish70-ablate-init --save_interval_ephemeral=100"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this isn't the final config anyways.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cleaned up in 92d2a08.

@epwalsh epwalsh merged commit 74de51d into main Mar 19, 2024
11 checks passed
@epwalsh epwalsh deleted the mitchish65-2 branch March 19, 2024 15:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants