Talkin' 'bout my generation #8
We don't have a specific script for it, but it wouldn't be hard to write it - you just need to create a
Right. Sometimes the latest checkpoint, or the checkpoint with the lowest training loss, doesn't actually generate good audio (unconditionally). It might get "trapped" in an attractor state. As part of the Dadabots process, we listen to the short sequences to find a good checkpoint before generating longer sequences.
Hi, I'm currently trying to write a script to generate audio from a saved checkpoint. Could I possibly get some more specific instructions on how to go about this? I should clarify: I currently have a script, but I'm getting an error when loading the checkpoint. It says "Missing key(s) in state_dict: " and lists a bunch of keys, then says "Unexpected key(s) in state_dict: " and lists another bunch of keys. Here is a pastebin link to the exact output: https://pastebin.com/cFNrnr7e EDIT: I've fixed the error. Loading the checkpoint put "model." in front of every key name, so that prefix had to be trimmed out.
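For reference, the fix amounts to renaming the keys before calling load_state_dict. A minimal sketch of that renaming (the helper name strip_model_prefix is mine, not from the repo):

```python
from collections import OrderedDict

def strip_model_prefix(state_dict):
    """Return a copy of state_dict with a leading 'model.' removed from each key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        cleaned[key.replace("model.", "", 1)] = value
    return cleaned

# Usage sketch: model.load_state_dict(strip_model_prefix(torch.load(PRETRAINED_PATH)))
```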
New issue: the .wav file that I'm generating and saving can't be played back by anything. EDIT: Disregard, I solved my issue.
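The poster doesn't say what their fix was, but one common cause of an unplayable .wav is writing float or mis-scaled sample data. A hedged sketch (function name and scaling are my assumptions, not the repo's code) of writing quantized samples as 16-bit PCM with the stdlib wave module:

```python
import struct
import wave

def write_wav(path, samples, sample_rate=16000, q_levels=256):
    """Map quantized samples in [0, q_levels) to int16 PCM and write a mono WAV."""
    with wave.open(path, "wb") as wf:
        wf.setnchannels(1)    # mono
        wf.setsampwidth(2)    # 16-bit PCM
        wf.setframerate(sample_rate)
        # Rescale [0, q_levels - 1] to [-32767, 32767] before packing as int16
        pcm = [int((s / (q_levels - 1) * 2.0 - 1.0) * 32767) for s in samples]
        wf.writeframes(struct.pack("<%dh" % len(pcm), *pcm))
```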
My solution for this is described in the following steps (and implemented in this fork):

1. Save the model parameters in a JSON file in train.py:

```python
def main(exp, frame_sizes, dataset, **params):
    params = dict(
        default_params,
        exp=exp, frame_sizes=frame_sizes, dataset=dataset,
        **params
    )
    import json
    with open(os.path.join(results_path, 'sample_rnn_params.json'), 'w') as fp:
        json.dump(params, fp, sort_keys=True, indent=4)
    ...
```

2. Add another register function to the GeneratorPlugin class in plugins.py that accepts the trained model and the CUDA setting as inputs:

```python
class GeneratorPlugin(Plugin):
    ...
    def register(self, trainer):
        self.generate = Generator(trainer.model.model, trainer.cuda)

    def register_generate(self, model, cuda):
        self.generate = Generator(model, cuda)
    ...
```

3. Create another Python script, say generate_audio.py, that generates new audio from the trained model:

```python
from model import SampleRNN
import torch
from collections import OrderedDict
import os
import json
from trainer.plugins import GeneratorPlugin

# Paths
RESULTS_PATH = 'results/exp:TEST-frame_sizes:16,4-n_rnn:2-piano/'
PRETRAINED_PATH = RESULTS_PATH + 'checkpoints/best-ep65-it79430'
GENERATED_PATH = RESULTS_PATH + 'generated/'
if not os.path.exists(GENERATED_PATH):
    os.mkdir(GENERATED_PATH)

# Load model parameters from .json for audio generation
params_path = RESULTS_PATH + 'sample_rnn_params.json'
with open(params_path, 'r') as fp:
    params = json.load(fp)

# Create model with the same parameters as used in training
model = SampleRNN(
    frame_sizes=params['frame_sizes'],
    n_rnn=params['n_rnn'],
    dim=params['dim'],
    learn_h0=params['learn_h0'],
    q_levels=params['q_levels'],
    weight_norm=params['weight_norm']
)

# Delete "model." from key names, since saving the checkpoint attaches it to the key names
pretrained_state = torch.load(PRETRAINED_PATH)
new_pretrained_state = OrderedDict()
for k, v in pretrained_state.items():
    layer_name = k.replace("model.", "")
    new_pretrained_state[layer_name] = v
    # print("k: {}, layer_name: {}, v: {}".format(k, layer_name, np.shape(v)))

# Load pretrained model
model.load_state_dict(new_pretrained_state)

# Generator plugin
generator = GeneratorPlugin(GENERATED_PATH, params['n_samples'], params['sample_length'], params['sample_rate'])

# Call the new register function to accept the trained model and the CUDA setting
generator.register_generate(model.cuda(), params['cuda'])

# Generate new audio
generator.epoch('Test')
```

P.S.: Thank you @kurah for the "Unexpected keys" error solution presented above.
@gcunhase just curious if you or anyone here has a method to supply your own "seed input" for generation. As in, I want to supply some novel input and see what the model generates from that.
Thanks for this code contribution!
Is there a way to just generate samples based on a given checkpoint without training?
The Generator is buried in the trainer code and teasing it out looks daunting.
Best,
- lonce