Commit 2cca810

Authored by LauraGPT, lyblsgo, and R1ckShi
Funasr1.0 (#1275)
* funasr1.0 finetune
* funasr1.0 pbar
* update with main (#1260, #1261, #1264): Update websocket_protocol_zh.md
* funasr1.0 sanm scama
* funasr1.0 infer_after_finetune
* funasr1.0 fsmn-vad bug fix
* bug fix

Co-authored-by: Yabin Li <[email protected]>
Co-authored-by: shixian.shi <[email protected]>
1 parent 12496e5 commit 2cca810

File tree: 6 files changed (+21, -11 lines)

examples/industrial_data_pretraining/paraformer/finetune.sh

Lines changed: 2 additions & 2 deletions

```diff
@@ -11,9 +11,9 @@ python funasr/bin/train.py \
 +model_revision="v2.0.2" \
 +train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
 +valid_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
-++dataset_conf.batch_size=2 \
+++dataset_conf.batch_size=64 \
 ++dataset_conf.batch_type="example" \
 ++train_conf.max_epoch=2 \
+++dataset_conf.num_workers=4 \
 +output_dir="outputs/debug/ckpt/funasr2/exp2" \
-+device="cpu" \
 +debug="true"
```

funasr/auto/auto_model.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -132,7 +132,7 @@ def __init__(self, **kwargs):
         self.punc_kwargs = punc_kwargs
         self.spk_model = spk_model
         self.spk_kwargs = spk_kwargs
-        self.model_path = kwargs["model_path"]
+        self.model_path = kwargs.get("model_path", "./")
 
 
     def build_model(self, **kwargs):
```
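Indexing `kwargs["model_path"]` raises `KeyError` whenever a caller builds `AutoModel` without passing `model_path`; `dict.get` swaps that hard failure for a `"./"` fallback. A short illustration (the kwargs contents here are made up):

```python
# Illustrative kwargs without a "model_path" entry.
kwargs = {"model": "paraformer-zh"}

# kwargs["model_path"]                       # old behavior: raises KeyError
model_path = kwargs.get("model_path", "./")  # new behavior: defaults to "./"
print(model_path)                            # -> ./
```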

funasr/bin/train.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -40,7 +40,7 @@ def main_hydra(kwargs: DictConfig):
 
 
 def main(**kwargs):
-
+    print(kwargs)
     # set random seed
     tables.print()
     set_all_random_seed(kwargs.get("seed", 0))
```
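The added `print(kwargs)` simply echoes the resolved config at startup; the seeding call right after it is what makes runs reproducible. A sketch of what a `set_all_random_seed` helper conventionally does (this is an assumption about FunASR's implementation, not a copy of it):

```python
# Conventional all-RNG seeding helper; assumed shape, not FunASR's actual code.
import random

import numpy as np
import torch

def set_all_random_seed(seed: int) -> None:
    random.seed(seed)        # Python stdlib RNG
    np.random.seed(seed)     # NumPy RNG
    torch.manual_seed(seed)  # PyTorch RNG (also seeds CUDA devices)

set_all_random_seed(0)  # mirrors kwargs.get("seed", 0) above
```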

funasr/datasets/audio_datasets/samplers.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -28,7 +28,7 @@ def __init__(self, dataset,
         self.shuffle = shuffle and is_training
 
     def __len__(self):
-        return self.total_samples
+        return (self.total_samples-1) // self.batch_size + 1
 
     def set_epoch(self, epoch):
         np.random.seed(epoch)
```
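The old `__len__` reported the number of samples, so anything sized from `len(sampler)` (progress bars, per-epoch step counts) was `batch_size` times too large; the fix returns the number of batches via the integer ceiling-division idiom `(n - 1) // b + 1`. A quick standalone check with illustrative values:

```python
# Ceiling division: how many batches of size b cover n samples (n >= 1).
def num_batches(total_samples: int, batch_size: int) -> int:
    return (total_samples - 1) // batch_size + 1

assert num_batches(10, 4) == 3   # batches of 4, 4, 2
assert num_batches(8, 4) == 2    # exact multiple adds no extra batch
assert num_batches(1, 64) == 1   # a lone sample still forms one batch
```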

funasr/models/fsmn_vad_streaming/model.py

Lines changed: 0 additions & 2 deletions

```diff
@@ -255,7 +255,6 @@ def __init__(self,
         self.waveform = None
         self.last_drop_frames = 0
 
-
 @tables.register("model_classes", "FsmnVADStreaming")
 class FsmnVADStreaming(nn.Module):
     """
@@ -500,7 +499,6 @@ def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {},
         # # reset class variables and clear the dict for the next query
         # self.AllResetDetection()
         return segments
-
 
     def init_cache(self, cache: dict = {}, **kwargs):
 
```

funasr/train_utils/trainer.py

Lines changed: 16 additions & 4 deletions

```diff
@@ -147,9 +147,17 @@ def run(self):
         for epoch in range(self.start_epoch, self.max_epoch + 1):
 
             self._train_epoch(epoch)
+
 
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+
             self._validate_epoch(epoch)
-
+
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+
+
             if self.rank == 0:
                 self._save_checkpoint(epoch)
 
@@ -164,7 +172,9 @@ def run(self):
 
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        self.writer.close()
+
+        if self.writer:
+            self.writer.close()
 
 
     def _train_epoch(self, epoch):
@@ -230,6 +240,8 @@ def _train_epoch(self, epoch):
                 continue
 
             # Execute an optimization step (update model parameters)
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
             self.optim.step()
             self.scheduler.step()
             # Clear gradients for the next accumulation stage
@@ -244,7 +256,7 @@ def _train_epoch(self, epoch):
                 pbar.update(1)
                 if self.local_rank == 0:
                     description = (
-                        f"Epoch: {epoch}/{self.max_epoch}, "
+                        f"Train epoch: {epoch}/{self.max_epoch}, "
                         f"step {batch_idx}/{len(self.dataloader_train)}, "
                         f"{speed_stats}, "
                         f"(loss: {loss.detach().cpu().item():.3f}), "
@@ -306,7 +318,7 @@ def _validate_epoch(self, epoch):
                 pbar.update(1)
                 if self.local_rank == 0:
                     description = (
-                        f"validation: \nEpoch: {epoch}/{self.max_epoch}, "
+                        f"validation epoch: {epoch}/{self.max_epoch}, "
                         f"step {batch_idx}/{len(self.dataloader_train)}, "
                         f"{speed_stats}, "
                         f"(loss: {loss.detach().cpu().item():.3f}), "
```
