Small fix for finetuning (#118)
* Small fixes for finetuning

* fix
loubbrad authored Dec 24, 2024
1 parent 40a245c commit fedf763
Showing 6 changed files with 30 additions and 36 deletions.
aria/datasets.py (6 changes: 1 addition & 5 deletions)

```diff
@@ -1060,7 +1060,7 @@ def _get_onset_adjusted_msg(

         return _temp_note_msg

-    _note_msgs = midi_dict.note_msgs
+    _note_msgs = copy.deepcopy(midi_dict.note_msgs)

     # Remove notes
     if random.random() < config["remove_notes"]["activation_prob"]:
```
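The `copy.deepcopy` matters because the noising transforms below mutate the note messages in place; holding a bare reference would corrupt the source `MidiDict` for later epochs. A minimal illustration of the aliasing hazard, using plain dicts in place of the repo's actual classes:

```python
import copy

midi_dict = {"note_msgs": [{"pitch": 60, "vel": 90}, {"pitch": 64, "vel": 80}]}

# Shallow reference: mutating the working list also mutates the source.
note_msgs = midi_dict["note_msgs"]
note_msgs.pop(0)
assert len(midi_dict["note_msgs"]) == 1  # source silently lost a note

# Deep copy: the source survives repeated noising passes.
midi_dict = {"note_msgs": [{"pitch": 60, "vel": 90}, {"pitch": 64, "vel": 80}]}
note_msgs = copy.deepcopy(midi_dict["note_msgs"])
note_msgs.pop(0)
assert len(midi_dict["note_msgs"]) == 2  # original untouched
```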
```diff
@@ -1270,10 +1270,6 @@ def _build_epoch(_save_path, _midi_dataset):
             if _idx % 250 == 0:
                 logger.info(f"Finished processing {_idx}")

-            # DEBUG
-            if _idx == 1000:
-                break
-
    logger = setup_logger()
    assert max_seq_len > 0, "max_seq_len must be greater than 0"
    assert num_epochs > 0, "num_epochs must be greater than 0"
```
aria/run.py (17 changes: 8 additions & 9 deletions)

```diff
@@ -152,17 +152,20 @@ def sample(args):
        guidance_midi_dict=guidance_midi_dict,
    )

-    if guidance_seq:
-        tokenizer.detokenize(guidance_seq).to_midi().save(
-            os.path.join(samples_dir, f"guidance.mid")
-        )
    if len(prompt_seq) + args.l > model_config.max_seq_len:
        print(
            "WARNING: Required context exceeds max_seq_len supported by model"
        )
    prompts = [prompt_seq for _ in range(num_variations)]

-    if args.cfg is not None:
+    samples_dir = os.path.join(os.path.dirname(__file__), "..", "samples")
+    if os.path.isdir(samples_dir) is False:
+        os.mkdir(samples_dir)
+    if guidance_seq:
+        tokenizer.detokenize(guidance_seq).to_midi().save(
+            os.path.join(samples_dir, f"guidance.mid")
+        )
+    if args.cfg is not None and guidance_seq is not None:
        results = sample_batch_cfg(
            model=model,
            tokenizer=tokenizer,
```
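Two fixes land here: `samples_dir` is now created before the guidance MIDI is saved (previously the save ran before the directory was guaranteed to exist), and CFG sampling is only dispatched when a guidance sequence is actually present. A condensed sketch of the corrected control flow, with hypothetical stand-in callables for the script's real objects:

```python
import os

def dispatch(cfg_scale, guidance_seq, save_midi, sample_cfg, sample_plain):
    # Create the output directory first, so saving guidance.mid cannot fail.
    samples_dir = os.path.join(os.path.dirname(__file__), "..", "samples")
    if os.path.isdir(samples_dir) is False:
        os.mkdir(samples_dir)

    # Guidance is saved whenever it exists, whether or not CFG is used.
    if guidance_seq:
        save_midi(guidance_seq, os.path.join(samples_dir, "guidance.mid"))

    # CFG requires a guidance sequence; otherwise fall back to plain sampling.
    if cfg_scale is not None and guidance_seq is not None:
        return sample_cfg()
    return sample_plain()
```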
```diff
@@ -186,10 +189,6 @@ def sample(args):
            compile=args.compile,
        )

-    samples_dir = os.path.join(os.path.dirname(__file__), "..", "samples")
-    if os.path.isdir(samples_dir) is False:
-        os.mkdir(samples_dir)
-
    for idx, tokenized_seq in enumerate(results):
        res_midi_dict = tokenizer.detokenize(tokenized_seq)
        res_midi = res_midi_dict.to_midi()
```
aria/sample.py (6 changes: 5 additions & 1 deletion)

```diff
@@ -303,6 +303,7 @@ def sample_batch_cfg(
        logits_cfg[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float(
            "-inf"
        )
+        logits_cfg[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf")

        if temperature > 0.0:
            probs = torch.softmax(logits_cfg / temperature, dim=-1)
```
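Setting a logit to `-inf` before the softmax forces that token's probability to exactly zero, so the sampler can never emit `dim_tok` during CFG generation. A self-contained sketch of the masking over a toy vocabulary:

```python
import torch

logits = torch.tensor([[1.2, 0.3, -0.5, 2.0]])  # batch of 1, vocab of 4
banned_id = 3  # pretend this is tok_to_id[dim_tok]

logits[:, banned_id] = float("-inf")
probs = torch.softmax(logits / 0.95, dim=-1)  # temperature 0.95

assert probs[0, banned_id] == 0.0  # masked token can never be sampled
next_tok = torch.multinomial(probs, num_samples=1)
```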
```diff
@@ -389,7 +390,10 @@ def get_inference_prompt(
    if tokenizer.dim_tok in prompt_seq:
        prompt_seq.remove(tokenizer.dim_tok)

-    if guidance_midi_dict is not None:
+    if (
+        guidance_midi_dict is not None
+        and tokenizer.guidance_start_tok in prompt_seq
+    ):
        guidance_seq = copy.deepcopy(prompt_seq)
        guidance_seq = guidance_seq[
            : guidance_seq.index(tokenizer.guidance_end_tok)
```
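The extra membership test guards the slice below: `list.index` raises `ValueError` when its argument is absent, so a prompt without an embedded guidance segment would previously have crashed here. A toy illustration with strings standing in for tokens:

```python
prompt_seq = ["<guidance_start>", "note_a", "note_b", "<guidance_end>", "piano"]

if "<guidance_start>" in prompt_seq:
    # Safe: the guidance segment is known to be present.
    guidance_seq = prompt_seq[: prompt_seq.index("<guidance_end>")]
else:
    guidance_seq = None  # no guidance segment, skip slicing

# Without the guard, a guidance-free prompt raises:
# ["piano"].index("<guidance_end>")  -> ValueError
```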
aria/tokenizer.py (4 changes: 2 additions & 2 deletions)

```diff
@@ -174,8 +174,8 @@ def detokenize(self, tokenized_seq: list, **kwargs):

    def export_data_aug(self):
        return [
-            self.export_guidance_tempo_aug(max_tempo_aug=0.25, mixup=True),
-            self.export_guidance_pitch_aug(4),
+            self.export_guidance_tempo_aug(max_tempo_aug=0.2, mixup=True),
+            self.export_guidance_pitch_aug(3),
            self.export_guidance_velocity_aug(2),
        ]
```
aria/train.py (5 changes: 1 addition & 4 deletions)

```diff
@@ -61,9 +61,6 @@
 # -bs 32 \
 # -workers 8

-# TODO:
-# - Test that everything works on a distributed setup
-

 def setup_logger(project_dir: str):
     # Get logger and reset all handlers

@@ -196,7 +193,7 @@ def get_optim(
    num_epochs: int,
    steps_per_epoch: int,
):
-    LR = 3e-5
+    LR = 3e-4
    END_RATIO = 0.1
    WARMUP_STEPS = 200
```
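The fine-tuning learning rate jumps an order of magnitude to 3e-4, alongside the existing 200-step warmup and a decay floor of `END_RATIO * LR`. A minimal sketch of one schedule consistent with these constants, using stock PyTorch schedulers (the repo's actual `get_optim` may wire this differently):

```python
import torch

model = torch.nn.Linear(16, 16)  # stand-in model
LR, END_RATIO, WARMUP_STEPS = 3e-4, 0.1, 200
total_steps = 10_000  # num_epochs * steps_per_epoch

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1e-8, end_factor=1.0, total_iters=WARMUP_STEPS
)
decay = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1.0, end_factor=END_RATIO,
    total_iters=total_steps - WARMUP_STEPS,
)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, decay], milestones=[WARMUP_STEPS]
)
```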
config/config.json (28 changes: 13 additions & 15 deletions)

```diff
@@ -182,31 +182,29 @@
        "min_clean_interval_ms": 60000,
        "max_clean_interval_ms": 200000,
        "noising": {
-            "activation_prob": 0.8,
+            "activation_prob": 0.5,
            "remove_notes": {
-                "activation_prob": 0.5,
+                "activation_prob": 0.25,
                "min_ratio": 0.0,
-                "max_ratio": 0.3
+                "max_ratio": 0.15
            },
            "adjust_velocity": {
-                "activation_prob": 0.3,
+                "activation_prob": 0.25,
                "min_adjust": 1,
-                "max_adjust": 30,
-                "max_ratio": 0.1,
-                "min_ratio": 0.30
+                "max_adjust": 20
            },
            "adjust_onsets": {
                "activation_prob": 0.25,
-                "min_adjust_s": 0.01,
-                "max_adjust_s": 0.07,
-                "max_ratio": 0.15,
-                "min_ratio": 0.3
+                "min_adjust_s": 0.005,
+                "max_adjust_s": 0.05,
+                "max_ratio": 0.0,
+                "min_ratio": 0.2
            },
            "quantize_onsets": {
-                "activation_prob": 0.15,
+                "activation_prob": 0.05,
                "min_quant_s": 0.05,
-                "max_quant_s": 0.15,
-                "max_vel_delta": 45
+                "max_quant_s": 0.1,
+                "max_vel_delta": 30
            }
        }
    }
```
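These values roughly halve both how often noising fires and how destructive each transform is. Per the `random.random() < config["remove_notes"]["activation_prob"]` check in `aria/datasets.py` above, each transform is gated independently. A toy sketch of that gating, assuming the top-level `activation_prob` gates the whole noising pass (helper and structure are illustrative, not the repo's code):

```python
import random

noising_cfg = {
    "activation_prob": 0.5,
    "remove_notes": {"activation_prob": 0.25, "min_ratio": 0.0, "max_ratio": 0.15},
}

def maybe_noise(note_msgs: list) -> list:
    # Top-level gate: half of all examples are left clean.
    if random.random() >= noising_cfg["activation_prob"]:
        return note_msgs
    # Per-transform gate: drop a random fraction of notes.
    cfg = noising_cfg["remove_notes"]
    if random.random() < cfg["activation_prob"]:
        ratio = random.uniform(cfg["min_ratio"], cfg["max_ratio"])
        return [m for m in note_msgs if random.random() >= ratio]
    return note_msgs
```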
```diff
@@ -215,7 +213,7 @@
    "inference_abs": {
        "guidance": {
            "min_ms": 5000,
-            "max_ms": 30000
+            "max_ms": 40000
        }
```
