diff --git a/audiocraft/models/audiogen.py b/audiocraft/models/audiogen.py index 5cb88998..b4df536e 100644 --- a/audiocraft/models/audiogen.py +++ b/audiocraft/models/audiogen.py @@ -38,6 +38,10 @@ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, self.name = name self.compression_model = compression_model self.lm = lm + # Just to be safe, let's put everything in eval mode. + self.compression_model.eval() + self.lm.eval() + if max_duration is None: if hasattr(lm, 'cfg'): max_duration = lm.cfg.dataset.segment_duration # type: ignore diff --git a/config/model/lm/musicgen_lm.yaml b/config/model/lm/musicgen_lm.yaml index 5bc87a62..be1fbc14 100644 --- a/config/model/lm/musicgen_lm.yaml +++ b/config/model/lm/musicgen_lm.yaml @@ -18,7 +18,7 @@ codebooks_pattern: delays: [0, 0, 0, 0] music_lm: group_by: 2 - valle: + coarse_first: delays: [0, 0, 0] transformer_lm: