diff --git a/CHANGELOG.md b/CHANGELOG.md index 6036b72f..c1cf0c0a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). Adding stereo models. +Removed compression model state from the LM checkpoints, for consistency, it +should always be loaded from the original `compression_model_checkpoint`. + ## [1.1.0] - 2023-11-06 diff --git a/audiocraft/__init__.py b/audiocraft/__init__.py index 8b7acf22..840aa263 100644 --- a/audiocraft/__init__.py +++ b/audiocraft/__init__.py @@ -23,4 +23,4 @@ # flake8: noqa from . import data, modules, models -__version__ = '1.2.0a1' +__version__ = '1.2.0a2' diff --git a/audiocraft/solvers/musicgen.py b/audiocraft/solvers/musicgen.py index 2439da33..72b65338 100644 --- a/audiocraft/solvers/musicgen.py +++ b/audiocraft/solvers/musicgen.py @@ -25,7 +25,7 @@ from ..modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition from ..utils.cache import CachedBatchWriter, CachedBatchLoader from ..utils.samples.manager import SampleManager -from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once +from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once, model_hash class MusicGenSolver(base.StandardSolver): @@ -143,7 +143,7 @@ def build_model(self) -> None: # initialize optimization self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim) self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates) - self.register_stateful('compression_model', 'model', 'optimizer', 'lr_scheduler') + self.register_stateful('model', 'optimizer', 'lr_scheduler') self.register_best_state('model') self.autocast_dtype = { 'float16': torch.float16, 'bfloat16': torch.bfloat16 @@ -181,6 +181,22 @@ def load_state_dict(self, state: dict) -> None: key = prefix + key assert key not in model_state model_state[key] = value + if 'compression_model' in state: + # We used to store the `compression_model` state in the checkpoint, however + # this is in general not needed, as the compression model should always be readable + # from the original `cfg.compression_model_checkpoint` location. + compression_model_state = state.pop('compression_model') + before_hash = model_hash(self.compression_model) + self.compression_model.load_state_dict(compression_model_state) + after_hash = model_hash(self.compression_model) + if before_hash != after_hash: + raise RuntimeError( + "The compression model state inside the checkpoint is different" + " from the one obtained from compression_model_checkpoint..." + "We do not support altering the compression model inside the LM " + "checkpoint as parts of the code, in particular for running eval post-training " + "will use the compression_model_checkpoint as the source of truth.") + super().load_state_dict(state) def load_from_pretrained(self, name: str):