Skip to content

Commit

Permalink
fix bugs for finetuning unispeech
Browse files Browse the repository at this point in the history
  • Loading branch information
cywang committed Apr 7, 2022
1 parent 295c961 commit e3043e2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/fairseq/models/unispeech/unispeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def forward(self, **kwargs):
x = self.w2v_encoder(**kwargs)
return x

def remove_pretraining_modules(self):
self.w2v_encoder.proj = None

class Wav2VecEncoder(FairseqEncoder):
def __init__(self, cfg, task):
super().__init__(task.source_dictionary)
Expand Down
7 changes: 5 additions & 2 deletions src/fairseq/models/wav2vec/wav2vec2_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,16 +353,19 @@ def __init__(self, cfg: Wav2Vec2AsrConfig, output_size=None):
task = tasks.setup_task(w2v_args.task)
model = task.build_model(w2v_args.model)

model.remove_pretraining_modules()
if state is not None and not cfg.no_pretrained_weights:
model.load_state_dict(state["model"], strict=True)

model.remove_pretraining_modules()

super().__init__(task.source_dictionary)

d = w2v_args.model.encoder_embed_dim

self.w2v_model = model
if hasattr(model, 'w2v_encoder'):
self.w2v_model = model.w2v_encoder.w2v_model
else:
self.w2v_model = model

self.final_dropout = nn.Dropout(cfg.final_dropout)
self.freeze_finetune_updates = cfg.freeze_finetune_updates
Expand Down

0 comments on commit e3043e2

Please sign in to comment.