Skip to content

Commit

Permalink
Fix grad checkpointing config in sampling CLI (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
loubbrad authored Dec 18, 2023
1 parent a95f4b5 commit 3667e27
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions aria/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ def _get_midi_path(midi_path: str | None) -> str:
return midi_path


# TODO: Add arg for supressing the audio conversion, and commands for changing
# the sampling params from the cli
def sample(args):
"""Entrypoint for sampling"""

Expand Down Expand Up @@ -154,6 +152,7 @@ def sample(args):

model_config = ModelConfig(**load_model_config(model_name))
model_config.set_vocab_size(tokenizer.vocab_size)
model_config.grad_checkpoint = False
model = TransformerLM(model_config).to(device)

if args.trunc + args.l > model_config.max_seq_len:
Expand Down

0 comments on commit 3667e27

Please sign in to comment.