SpeechLM Update #12430

Open
wants to merge 165 commits into base: main
Commits (165)
d967c51
fix type bugs
stevehuang52 Sep 16, 2024
84eaa59
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
stevehuang52 Sep 16, 2024
22217d0
Merge remote-tracking branch 'origin/main' into slm_v2
stevehuang52 Sep 19, 2024
799a5ec
Merge remote-tracking branch 'origin/main' into slm_v2
stevehuang52 Sep 19, 2024
568f073
Update mixin.py
stevehuang52 Sep 20, 2024
f91030d
Apply isort and black reformatting
stevehuang52 Sep 20, 2024
c9256ca
Update mixin.py
stevehuang52 Sep 21, 2024
e41e4f9
Apply isort and black reformatting
stevehuang52 Sep 21, 2024
2a70bd3
Update mixin.py
stevehuang52 Sep 23, 2024
725fd88
Merge branch 'main' into slm_v2
stevehuang52 Sep 23, 2024
de70200
add datamodule
stevehuang52 Sep 25, 2024
98a2673
Merge remote-tracking branch 'origin/main' into slm_v2
stevehuang52 Oct 1, 2024
a00e5dd
add speechlm peft train, continue train, validation and misc
stevehuang52 Oct 8, 2024
52a1b0e
Merge remote-tracking branch 'origin/main' into heh/speechlm_nemo2.0
stevehuang52 Oct 8, 2024
9d57bc6
resolve merge confict
stevehuang52 Oct 8, 2024
ae3c5be
update datamodule
stevehuang52 Oct 8, 2024
c2ef987
Merge remote-tracking branch 'origin/main' into heh/speechlm_nemo2.0
stevehuang52 Oct 8, 2024
dc951e2
add script
stevehuang52 Oct 8, 2024
248f4e2
fix speechlm inference
stevehuang52 Oct 14, 2024
c42e17e
update
stevehuang52 Oct 14, 2024
9ec5a96
mergin origin/main
stevehuang52 Oct 14, 2024
90dcb6a
Merge remote-tracking branch 'origin/main' into heh/speechlm_nemo2.0
stevehuang52 Oct 15, 2024
0a93523
fix tp support
stevehuang52 Oct 15, 2024
5c87470
Apply isort and black reformatting
stevehuang52 Oct 15, 2024
934b0db
Apply isort and black reformatting
artbataev Oct 15, 2024
e7af9a7
Merge remote-tracking branch 'origin/main' into heh/speechlm_nemo2.0
stevehuang52 Oct 22, 2024
0c32fa5
Merge remote-tracking branch 'origin/main' into heh/speechlm_nemo2.0
stevehuang52 Oct 29, 2024
9d98e89
update
stevehuang52 Oct 30, 2024
b659912
refactor
stevehuang52 Nov 1, 2024
3653320
Merge remote-tracking branch 'origin/main' into heh/speechlm_nemo2.0
stevehuang52 Nov 13, 2024
31bc21c
Merge remote-tracking branch 'origin/main' into heh/speechlm_nemo2.0
stevehuang52 Nov 14, 2024
2256602
Merge remote-tracking branch 'origin/main' into heh/speechlm_nemo2.0
stevehuang52 Nov 18, 2024
6a18e9a
Merge remote-tracking branch 'origin/main' into heh/speechlm_nemo2.0
stevehuang52 Nov 19, 2024
5ca4729
updated for latest changes in main, tested for interleave audio-text …
stevehuang52 Nov 20, 2024
31378a5
fix nemo tp>1, and lhotse tp=1
stevehuang52 Nov 20, 2024
6bf47f7
update prompt formatter
stevehuang52 Nov 26, 2024
ccd78ba
update cfg
stevehuang52 Nov 26, 2024
239d0f9
update asr
stevehuang52 Nov 26, 2024
dd9aba4
update tokenizer
stevehuang52 Nov 26, 2024
2bd5ce7
update prompt format fn
stevehuang52 Nov 26, 2024
34531ca
update prompt format fn
stevehuang52 Nov 26, 2024
41c36ea
update lhotse
stevehuang52 Nov 26, 2024
2f240bf
update data sampler
stevehuang52 Nov 26, 2024
644c0b8
update data sampler
stevehuang52 Nov 26, 2024
cb2661b
update lhotse
stevehuang52 Nov 26, 2024
3db0764
update lhotse
stevehuang52 Nov 26, 2024
a932673
update lhotse dataset
stevehuang52 Nov 26, 2024
bd8e3e8
debug
stevehuang52 Nov 26, 2024
44bfb2c
debug
stevehuang52 Nov 26, 2024
43bb800
debug
stevehuang52 Nov 27, 2024
08acb79
debug
stevehuang52 Nov 27, 2024
ea538ed
debug
stevehuang52 Nov 27, 2024
4045085
update
stevehuang52 Nov 27, 2024
108eb9d
update
stevehuang52 Nov 27, 2024
7d9f855
add wandb to cfg yaml
stevehuang52 Nov 27, 2024
1338245
add support for max_time_per_run
stevehuang52 Dec 2, 2024
1bc0210
add debug info
stevehuang52 Dec 2, 2024
988e585
add check
stevehuang52 Dec 3, 2024
7a7052d
update sampler
stevehuang52 Dec 3, 2024
e8ea92f
update dataset
stevehuang52 Dec 4, 2024
b3b0f33
update dataset
stevehuang52 Dec 4, 2024
63590b3
try fix peft resume
stevehuang52 Dec 4, 2024
e07d88a
try fix peft resume
stevehuang52 Dec 4, 2024
bc1d434
clean up
stevehuang52 Dec 4, 2024
354c2a9
refactor
stevehuang52 Dec 5, 2024
811ac6c
refactor
stevehuang52 Dec 5, 2024
55d1bb8
fix val dl name
stevehuang52 Dec 5, 2024
1425f82
fix val/test names/outputs
stevehuang52 Dec 5, 2024
d187a03
fix lhotse name
stevehuang52 Dec 5, 2024
1dc2d91
fix loss
stevehuang52 Dec 6, 2024
47c07c9
update data statedict
stevehuang52 Dec 9, 2024
ad35350
update data module
stevehuang52 Dec 10, 2024
5447525
update
stevehuang52 Dec 16, 2024
ea6618e
Merge remote-tracking branch 'origin/main' into heh/speechlm_nemo2.0
stevehuang52 Dec 16, 2024
d94b7d1
fix
stevehuang52 Dec 18, 2024
d51a675
fix data
stevehuang52 Dec 18, 2024
366dad6
fix data
stevehuang52 Dec 18, 2024
01f3d5a
fix data
stevehuang52 Dec 19, 2024
63b7097
fix data
stevehuang52 Jan 2, 2025
d9a1206
fix peft load ckpt
stevehuang52 Jan 3, 2025
2b1ecd9
Merge remote-tracking branch 'origin/main' into heh/speechlm_nemo2.0
stevehuang52 Jan 6, 2025
8f84111
fix megatron batch sampler
stevehuang52 Jan 7, 2025
8c21769
fix and clean up
stevehuang52 Jan 8, 2025
ae9009b
Merge remote-tracking branch 'origin/main' into heh/speechlm_nemo2.0
stevehuang52 Jan 13, 2025
bea40e5
fix for lhotse
stevehuang52 Jan 14, 2025
674289d
clean up
stevehuang52 Jan 14, 2025
809a4c4
refactor
stevehuang52 Jan 14, 2025
7844ae8
clean up
stevehuang52 Jan 14, 2025
fa34d07
Merge branch 'main' into heh/speechlm_nemo2.0
stevehuang52 Jan 14, 2025
bfa9599
clean up
stevehuang52 Jan 14, 2025
bdf8d26
clean up
stevehuang52 Jan 14, 2025
25062b3
clean up
stevehuang52 Jan 15, 2025
c9b5da4
clean up
stevehuang52 Jan 15, 2025
aa935ee
refactor data
stevehuang52 Jan 16, 2025
9e56f03
refactor data
stevehuang52 Jan 16, 2025
6949ef5
update llama prompt
stevehuang52 Jan 16, 2025
8769553
update llama prompt
stevehuang52 Jan 16, 2025
417afaa
update lhotse
stevehuang52 Jan 16, 2025
01cf0dc
diable warning
stevehuang52 Jan 16, 2025
b463408
revert changes
stevehuang52 Jan 16, 2025
0a3fccc
Merge remote-tracking branch 'origin/main' into heh/speechlm_nemo2.0
stevehuang52 Jan 16, 2025
bcd5711
refactor dataset
stevehuang52 Jan 17, 2025
1cdf312
clean up
stevehuang52 Jan 17, 2025
fb36e06
clean up
stevehuang52 Jan 23, 2025
7c2f159
refactor and clean up
stevehuang52 Jan 23, 2025
1e64395
clean up
stevehuang52 Jan 23, 2025
f96bfd7
clean up
stevehuang52 Jan 23, 2025
eaaa5f4
fix codeQL
stevehuang52 Jan 23, 2025
4a28305
Merge remote-tracking branch 'origin/main' into heh/speechlm_nemo2.0
stevehuang52 Jan 23, 2025
a130e13
fix codeQL
stevehuang52 Jan 23, 2025
855e7e4
address comments
stevehuang52 Jan 28, 2025
9318efd
fix when consumed samples more than total samples
stevehuang52 Jan 29, 2025
da2f8bf
add ci test
stevehuang52 Jan 29, 2025
687903c
Merge branch 'main' into heh/speechlm_nemo2.0
ko3n1g Jan 30, 2025
a8cb982
Update .github/workflows/cicd-main.yml
stevehuang52 Jan 30, 2025
cabd8f1
Merge remote-tracking branch 'origin/main' into heh/speechlm_nemo2.0
stevehuang52 Jan 30, 2025
e95b44c
disable test_subword_decoding_greedy_forward_hypotheses
stevehuang52 Jan 30, 2025
498db3d
fix ci
stevehuang52 Jan 30, 2025
1dab028
fix ci
stevehuang52 Jan 30, 2025
f653988
update for CI
stevehuang52 Jan 31, 2025
712dffa
fix megatron sampler and clean up
stevehuang52 Jan 31, 2025
a277e33
clean up
stevehuang52 Feb 3, 2025
610a1c4
Merge remote-tracking branch 'origin/main' into heh/speechlm_nemo2.0
stevehuang52 Feb 3, 2025
235686d
update ckpt import from HF
stevehuang52 Feb 3, 2025
8778e68
fix dist ckpt load
stevehuang52 Feb 3, 2025
93b2123
add exception handling
stevehuang52 Feb 3, 2025
d317129
refactor loading pretrained llm
stevehuang52 Feb 4, 2025
bca15f0
refactor and clean up
stevehuang52 Feb 4, 2025
4c2825d
Merge branch 'main' into heh/speechlm_nemo2.0
stevehuang52 Feb 4, 2025
c04abe4
refactor and fix
stevehuang52 Feb 6, 2025
721aba3
Merge branch 'heh/speechlm_nemo2.0' of https://github.com/NVIDIA/NeMo…
stevehuang52 Feb 6, 2025
bf42258
Merge remote-tracking branch 'origin/main' into heh/speechlm_nemo2.0
stevehuang52 Feb 6, 2025
54dfa56
refactor
stevehuang52 Feb 6, 2025
24f316d
refactor for peft
stevehuang52 Feb 7, 2025
d825e83
fix CI config
stevehuang52 Feb 8, 2025
b66f195
debug
stevehuang52 Feb 10, 2025
095c24e
added multimodal conversation
stevehuang52 Feb 11, 2025
66d3ece
update
stevehuang52 Feb 12, 2025
3901303
debug
stevehuang52 Feb 12, 2025
d9b29c4
debug
stevehuang52 Feb 13, 2025
07755a4
update cfg
stevehuang52 Feb 13, 2025
3b3ae9e
clean up
stevehuang52 Feb 13, 2025
1dcbd6c
Merge remote-tracking branch 'origin/main' into heh/speechlm_dev
stevehuang52 Feb 14, 2025
1fc1ffd
add cp and fix tp
stevehuang52 Feb 18, 2025
b93bcf0
add missing __init__
stevehuang52 Feb 18, 2025
144e39f
fix
stevehuang52 Feb 18, 2025
c0040d3
fix import ckpt
stevehuang52 Feb 18, 2025
c6791b0
update io
stevehuang52 Feb 19, 2025
74a04be
fix hf tokenizer remove_special_tokens
stevehuang52 Feb 20, 2025
957e729
refactor
stevehuang52 Feb 20, 2025
dd725e7
comment out lhotse assert
stevehuang52 Feb 20, 2025
04c6e77
update cfg
stevehuang52 Feb 20, 2025
902b0f6
update cfg
stevehuang52 Feb 21, 2025
d05cbd4
refactor and update inference
stevehuang52 Feb 25, 2025
b2a853f
update infer
stevehuang52 Feb 25, 2025
c0da97f
update
stevehuang52 Feb 25, 2025
f2f288a
update cfg
stevehuang52 Feb 25, 2025
fd6de0d
fix peft trainable params and update
stevehuang52 Feb 26, 2025
26bb5a4
add support for whisper encoder
stevehuang52 Feb 26, 2025
601364e
Merge remote-tracking branch 'origin/main' into heh/speechlm_dev
stevehuang52 Feb 26, 2025
b88f0c0
update
stevehuang52 Feb 28, 2025
d7eca13
clean up
stevehuang52 Mar 3, 2025
64e027e
clean up
stevehuang52 Mar 3, 2025
a66d99b
Merge remote-tracking branch 'origin/main' into heh/speechlm_dev
stevehuang52 Mar 3, 2025
0188d8d
clean up
stevehuang52 Mar 3, 2025
17 changes: 10 additions & 7 deletions examples/speechlm/conf/salm/salm_llama3-8b_fc_fc_peft.yaml
@@ -6,7 +6,7 @@ data:
common:
global_batch_size: 2
micro_batch_size: 2
- max_seq_length: 2048
+ max_seq_length: 4096
min_seq_length: 1
sample_rate: 16000
end_string: null
@@ -120,10 +120,11 @@ model:
_target_: nemo.collections.asr.models.EncDecMultiTaskModel
pretrained_model: "nvidia/canary-1b"
target_module: "encoder"
+ sample_rate: ${data.common.sample_rate}
spec_augment_config:
_target_: nemo.collections.asr.modules.SpectrogramAugmentation
- freq_masks: 2 # set to zero to disable it
- time_masks: 10 # set to zero to disable it
+ freq_masks: 0 # set to zero to disable it, otherwise use something like 2
+ time_masks: 0 # set to zero to disable it, otherwise use something like 10
freq_width: 27
time_width: 0.05

@@ -214,23 +215,25 @@ trainer:
devices: -1
accelerator: gpu
num_nodes: 1
- max_epochs: 1000 # used to keep epoch logging correctly, but training will stop based on max_steps
+ max_epochs: -1
max_steps: 1000000 # 1M steps
- log_every_n_steps: 2 # frequency with which training steps are logged
- val_check_interval: 1.0 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch
accumulate_grad_batches: 1
+ log_every_n_steps: 10 # frequency with which training steps are logged
+ val_check_interval: 2000 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch
num_sanity_val_steps: 0
sync_batchnorm: true # used for convolution modules like FC

strategy:
_target_: nemo.collections.speechlm.strategies.SpeechLMMegatronStrategy
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
+ context_parallel_size: 1
+ ckpt_async_save: true

callbacks:
checkpoint:
_target_: nemo.lightning.pytorch.callbacks.ModelCheckpoint
- filename: '${name}--{${callbacks.checkpoint.monitor}:.3f}-{step}-{epoch}'
+ filename: '${name}--{${callbacks.checkpoint.monitor}:.5f}-{step}'
monitor: "val_loss"
mode: "min"
save_last: true
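The SpecAugment change above zeroes out both mask counts by default. As a minimal sketch of the opt-in path the new comments describe (not part of the PR; it only assumes the config is consumed through OmegaConf, which the ${...} interpolation syntax implies), re-enabling masking with the suggested values would look like:

from omegaconf import OmegaConf

# Load the updated SALM config and turn SpecAugment back on with the values
# suggested in the comments ("something like 2" / "something like 10").
cfg = OmegaConf.load("examples/speechlm/conf/salm/salm_llama3-8b_fc_fc_peft.yaml")
cfg.model.speech_encoder.spec_augment_config.freq_masks = 2   # PR default is 0 (disabled)
cfg.model.speech_encoder.spec_augment_config.time_masks = 10  # PR default is 0 (disabled)

The same keys apply to the other SALM configs in this PR, which add an identical spec_augment_config block under model.speech_encoder.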
19 changes: 14 additions & 5 deletions examples/speechlm/conf/salm/salm_llama3-8b_fc_linear_peft.yaml
@@ -6,7 +6,7 @@ data:
common:
global_batch_size: 2
micro_batch_size: 2
- max_seq_length: 2048
+ max_seq_length: 4096
min_seq_length: 1
sample_rate: 16000
end_string: null
@@ -118,6 +118,13 @@ model:
speech_encoder:
pretrained_model: "stt_en_fastconformer_transducer_large"
target_module: "encoder"
+ sample_rate: ${data.common.sample_rate}
+ spec_augment_config:
+ _target_: nemo.collections.asr.modules.SpectrogramAugmentation
+ freq_masks: 0 # set to zero to disable it
+ time_masks: 0 # set to zero to disable it
+ freq_width: 27
+ time_width: 0.05

modality_adapter:
input_key_from: "d_model" # attribute of model dim in the speech model
@@ -161,23 +168,25 @@ trainer:
devices: -1
accelerator: gpu
num_nodes: 1
- max_epochs: 1000 # used to keep epoch logging correctly, but training will stop based on max_steps
+ max_epochs: -1
max_steps: 1000000 # 1M steps
- log_every_n_steps: 2 # frequency with which training steps are logged
- val_check_interval: 1.0 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch
accumulate_grad_batches: 1
+ log_every_n_steps: 10 # frequency with which training steps are logged
+ val_check_interval: 2000 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch
num_sanity_val_steps: 0
sync_batchnorm: true # used for convolution modules like FC

strategy:
_target_: nemo.collections.speechlm.strategies.SpeechLMMegatronStrategy
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
+ context_parallel_size: 1
+ ckpt_async_save: true

callbacks:
checkpoint:
_target_: nemo.lightning.pytorch.callbacks.ModelCheckpoint
- filename: '${name}--{${callbacks.checkpoint.monitor}:.3f}-{step}-{epoch}'
+ filename: '${name}--{${callbacks.checkpoint.monitor}:.5f}-{step}'
monitor: "val_loss"
mode: "min"
save_last: true
270 changes: 270 additions & 0 deletions examples/speechlm/conf/salm/salm_llama3.1-8b_fc_fc_peft.yaml
@@ -0,0 +1,270 @@
name: megatron_audio_gpt_peft


############ Data ############
data:
common:
global_batch_size: 2
micro_batch_size: 2
max_seq_length: 4096
min_seq_length: 1
sample_rate: 16000
end_string: null
context_key: 'context'
answer_key: 'answer'
prompt_format: null
prompt_template: "Q: {context}\nA: {answer}" # fstring to use for assistant prompt.
separate_prompt_and_response_with_newline: False
truncation_field: 'context'
add_eos: true
add_sep: false
add_bos: false
tokens_to_generate: 128
audio_locator: null
add_boa_eoa: false

train_ds:
# Example of how to specify paths to multiple datasets
# manifest_filepath:
# - /path/to/squad.jsonl
# - /path/to/mnli.jsonl
# - /path/to/boolq.jsonl
# Example of how each dataset is formatted
# {'audio_filepath': 'audio1.wav', 'offset': 0.0, 'duration': 12.3, 'context': 'transcribe this audio', 'answer': 'I have a dream...'}
# the 'answer' field can also be 'text', and a default 'context' field is added if missing in manifests, so as to work with ASR manifests
manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data.
global_batch_size: ${data.common.global_batch_size}
micro_batch_size: ${data.common.micro_batch_size}
shuffle: True
num_workers: 0
pin_memory: True
max_seq_length: ${data.common.max_seq_length}
min_seq_length: ${data.common.min_seq_length}
drop_last: True
# Notably, the data weights are controlled by either bucketing_weights
# or concat_sampling_probabilities depending on the dataset type (tar and
# non-tar).
concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random'
context_key: ${data.common.context_key}
answer_key: ${data.common.answer_key}
end_string: ${data.common.end_string}
add_eos: ${data.common.add_eos}
add_sep: ${data.common.add_sep}
add_bos: ${data.common.add_bos}
separate_prompt_and_response_with_newline: ${data.common.separate_prompt_and_response_with_newline}
truncation_field: ${data.common.truncation_field} # Options: ['context', 'answer']
prompt_template: ${data.common.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"
# ASR configs
sample_rate: ${data.common.sample_rate}
max_duration: 24 # it is set for LibriSpeech, you may need to update it for your dataset
min_duration: 0.1
# tarred datasets
is_concat: false
is_tarred: false
tarred_audio_filepaths: null
shuffle_n: 2048
# bucketing params
bucketing_strategy: "fully_randomized"
bucketing_batch_size: null
audio_locator: ${data.common.audio_locator}
prompt_format: ${data.common.prompt_format}

validation_ds:
manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds.
global_batch_size: ${data.common.global_batch_size}
micro_batch_size: ${data.common.micro_batch_size}
shuffle: False
num_workers: 0
pin_memory: True
max_seq_length: ${data.common.max_seq_length}
min_seq_length: ${data.common.min_seq_length}
drop_last: true # no effect, the dataloader will drop last for train and validation anyway
context_key: ${data.common.context_key}
answer_key: ${data.common.answer_key}
add_eos: ${data.common.add_eos}
end_string: ${data.common.end_string}
add_sep: ${data.common.add_sep}
add_bos: ${data.common.add_bos}
separate_prompt_and_response_with_newline: ${data.common.separate_prompt_and_response_with_newline}
output_file_path_prefix: null # Prefix of the file to write predictions to.
truncation_field: ${data.common.truncation_field} # Options: ['context', 'answer']
index_mapping_dir: null # Path to a directory to write index mapping files.
prompt_template: ${data.common.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"
tokens_to_generate: ${data.common.tokens_to_generate}
write_predictions_to_file: False
# ASR configs
sample_rate: ${data.common.sample_rate}
audio_locator: ${data.common.audio_locator}
prompt_format: ${data.common.prompt_format}

log_every_n_steps: 10
metric:
name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss', 'wer', 'bleu', 'rouge']
average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
num_classes: null


############ Model ############
model:
freeze_language_model: true
freeze_speech_model: false
freeze_modality_adapter: false

llm:
pretrained_model: "meta-llama/Llama-3.1-8B"
_target_: nemo.collections.llm.LlamaModel
config:
_target_: nemo.collections.llm.Llama31Config8B

speech_encoder:
_target_: nemo.collections.asr.models.EncDecMultiTaskModel
pretrained_model: "nvidia/canary-1b"
target_module: "encoder"
sample_rate: ${data.common.sample_rate}
spec_augment_config:
_target_: nemo.collections.asr.modules.SpectrogramAugmentation
freq_masks: 0 # set to zero to disable it, otherwise use something like 2
time_masks: 0 # set to zero to disable it, otherwise use something like 10
freq_width: 27
time_width: 0.05

modality_adapter:
input_key_from: "d_model" # attribute of model dim in the speech model
input_key_to: "feat_in" # attribute of input dim in the modality adapter
output_key: "feat_out" # attribute of output dim in the modality adapter
config:
_target_: nemo.collections.asr.modules.ConformerEncoder
feat_in: -1
feat_out: -1 # you may set it if you need different output size other than the default d_model
n_layers: 2
d_model: 512

# Sub-sampling parameters
subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding
subsampling_factor: 1 # must be power of 2 for striding and vggnet
subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model
causal_downsampling: false

# Reduction parameters: Can be used to add another subsampling layer at a given position.
# Having a 2x reduction will speed up the training and inference while keeping similar WER.
# Adding it at the end will give the best WER while adding it at the beginning will give the best speedup.
reduction: null # pooling, striding, or null
reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder
reduction_factor: 1

# Feed forward module's params
ff_expansion_factor: 4

# Multi-headed Attention Module's params
self_attention_model: rel_pos # rel_pos or abs_pos
n_heads: 8 # may need to be lower for smaller d_models
# [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
att_context_size: [-1, -1] # -1 means unlimited context
att_context_style: regular # regular or chunked_limited
xscaling: true # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000

# Convolution module's params
conv_kernel_size: 9
conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups)
# conv_context_size can be "causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size
# null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0]
conv_context_size: null

### regularization
dropout: 0.1 # The dropout used in most of the Conformer Modules
dropout_pre_encoder: 0.1 # The dropout used before the encoder
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0.1 # The dropout for multi-headed attention modules

# set to non-zero to enable stochastic depth
stochastic_depth_drop_prob: 0.0
stochastic_depth_mode: linear # linear or uniform
stochastic_depth_start_layer: 1

peft:
_target_: nemo.collections.llm.peft.LoRA
dim: 32


############ Optimizer ############
optim:
_target_: nemo.lightning.MegatronOptimizerModule
config:
_target_: megatron.core.optimizer.OptimizerConfig
optimizer: adam
lr: 1e-4
clip_grad: 1.0
weight_decay: 0.0001
lr_scheduler:
_target_: nemo.lightning.pytorch.optim.CosineAnnealingScheduler
max_steps: ${trainer.max_steps}
warmup_steps: 250
constant_steps: 10000
min_lr: 5e-5

############ Trainer ############

# Set this to "DD:HH:MM:SS" format to limit the max time for this job
# If `max_time_per_run` is set, `strategy.ckpt_async_save` must be set to false
max_time_per_run: null

trainer:
# _target_: nemo.lightning.Trainer
devices: -1
accelerator: gpu
num_nodes: 1
max_epochs: -1
max_steps: 1000000 # 1M steps
accumulate_grad_batches: 1
log_every_n_steps: 10 # frequency with which training steps are logged
val_check_interval: 2000 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch
num_sanity_val_steps: 0
sync_batchnorm: true # used for convolution modules like FC

strategy:
_target_: nemo.collections.speechlm.strategies.SpeechLMMegatronStrategy
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
context_parallel_size: 1
ckpt_async_save: true

callbacks:
checkpoint:
_target_: nemo.lightning.pytorch.callbacks.ModelCheckpoint
filename: '${name}--{${callbacks.checkpoint.monitor}:.5f}-{step}'
monitor: "val_loss"
mode: "min"
save_last: true
save_top_k: 1
save_weights_only: false
always_save_context: true

plugins:
_target_: nemo.lightning.MegatronMixedPrecision
precision: "bf16-mixed"
autocast_enabled: null

############ AutoResume ############
resume:
_target_: nemo.collections.speechlm.utils.resume.SpeechLMAutoResume
resume_from_directory: null
resume_from_path: null
adapter_path: null
resume_if_exists: true
resume_past_end: false
resume_ignore_no_checkpoint: true


############ Logging ############
logger:
_target_: nemo.lightning.NeMoLogger
log_dir: null # default to ./nemo_experiments
name: ${name}
wandb:
_target_: lightning.pytorch.loggers.WandbLogger
project: null
name: ${logger.name}
resume: false
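
For orientation, a minimal sketch of consuming this new config (not part of the PR; the OmegaConf usage is implied by the ??? placeholders and ${...} interpolations, and the manifest paths below are placeholders, not real files):

from omegaconf import OmegaConf

# Load the new SALM PEFT config for Llama-3.1-8B.
cfg = OmegaConf.load("examples/speechlm/conf/salm/salm_llama3.1-8b_fc_fc_peft.yaml")

# Both manifest_filepath fields are mandatory ("???") and must point to JSONL manifests
# in the format described in the train_ds comments; these paths are placeholders.
cfg.data.train_ds.manifest_filepath = ["/data/train_manifest.jsonl"]
cfg.data.validation_ds.manifest_filepath = ["/data/val_manifest.jsonl"]

# ${...} references resolve against the root config, e.g. the speech encoder's
# sample_rate picks up data.common.sample_rate (16000).
assert cfg.model.speech_encoder.sample_rate == 16000

As the comment above max_time_per_run notes, setting it to a DD:HH:MM:SS value also requires switching strategy.ckpt_async_save to false.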
