Skip to content

Commit

Permalink
Add transformers hub usage mbd (facebookresearch#202)
Browse files Browse the repository at this point in the history
* add transformers hub usage for mbd models

* small bug nit correction

* correct style
  • Loading branch information
ylacombe authored Aug 28, 2023
1 parent 3901c1a commit 4a4578a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
13 changes: 9 additions & 4 deletions models/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,17 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_di
return model


def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
return _get_state_dict(file_or_url_or_id, filename="all_in_one.pt", cache_dir=cache_dir)
def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
filename: tp.Optional[str] = None,
cache_dir: tp.Optional[str] = None):
return _get_state_dict(file_or_url_or_id, filename=filename, cache_dir=cache_dir)


def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
pkg = load_mbd_ckpt(file_or_url_or_id, cache_dir=cache_dir)
def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str],
device='cpu',
filename: tp.Optional[str] = None,
cache_dir: tp.Optional[str] = None):
pkg = load_mbd_ckpt(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
models = []
processors = []
cfgs = []
Expand Down
10 changes: 6 additions & 4 deletions models/multibanddiffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,11 @@ def get_mbd_musicgen(device=None):
"""Load our diffusion models trained for MusicGen."""
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
path = 'https://dl.fbaipublicfiles.com/encodec/Diffusion/mbd_musicgen_32khz.th'
path = 'facebook/multiband-diffusion'
filename = 'mbd_musicgen_32khz.th'
name = 'facebook/musicgen-small'
codec_model = load_compression_model(name, device=device)
models, processors, cfgs = load_diffusion_models(path, device=device)
models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
DPs = []
for i in range(len(models)):
schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
Expand Down Expand Up @@ -102,8 +103,9 @@ def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True,
'//pretrained/facebook/encodec_24khz', device=device)
codec_model.set_num_codebooks(n_q)
codec_model = codec_model.to(device)
path = f'https://dl.fbaipublicfiles.com/encodec/Diffusion/mbd_comp_{n_q}.pt'
models, processors, cfgs = load_diffusion_models(path, device=device)
path = 'facebook/multiband-diffusion'
filename = f'mbd_comp_{n_q}.pt'
models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
DPs = []
for i in range(len(models)):
schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
Expand Down

0 comments on commit 4a4578a

Please sign in to comment.