Live player #87

Closed
wants to merge 11 commits
4 changes: 1 addition & 3 deletions aria/data/datasets.py
@@ -832,8 +832,6 @@ def _build(_midi_dataset):

_build(_midi_dataset=midi_dataset)

logger.info(
f"Finished building, saved Finetuning to {save_path}"
)
logger.info(f"Finished building, saved Finetuning to {save_path}")

return cls(file_path=save_path, tokenizer=tokenizer)
62 changes: 49 additions & 13 deletions aria/model/cache.py
@@ -5,9 +5,26 @@

class KVCache(torch.nn.Module):
def __init__(
self, max_batch_size, n_head, d_head, dtype=torch.float16, max_size=8192
self,
max_batch_size,
n_head,
d_head,
dtype=torch.float16,
max_size=8192,
rolling=True,
):
"""
Cache for key-value pairs used in self-attention.
Args:
max_batch_size: the maximum batch size
n_head: the number of heads
d_head: the dimension of each head
dtype: the dtype of the cache
max_size: the maximum number of positions to cache
rolling: whether to wrap around and overwrite the oldest positions when the cache is full
"""
super().__init__()
self.rolling = rolling
self.shape = (max_batch_size, max_size, n_head, d_head)
self.register_buffer(
"k_cache", torch.empty(self.shape, dtype=dtype), persistent=False
@@ -17,10 +34,22 @@ def __init__(
)
self.next_pos = 0

def _get_tensor(self, cache, start_pos, next_pos):
if self.rolling and next_pos > self.shape[1]:
return torch.cat(
[
cache[:, next_pos % self.shape[1] :],
cache[:, : next_pos % self.shape[1]],
],
dim=1,
)
else:
return cache[:, start_pos:next_pos]

def update(
self,
k,
v,
k: torch.Tensor,
v: torch.Tensor,
pos: Optional[torch.Tensor] = None,
start_pos: int = 0,
max_pos: Optional[int] = None,
@@ -42,26 +71,33 @@ def update(
due to dynamic shape.
"""
if pos is None:
self.k_cache[
: k.size(0), self.next_pos : self.next_pos + k.size(1)
] = k
self.v_cache[
: v.size(0), self.next_pos : self.next_pos + v.size(1)
] = v
k_pos = torch.arange(
self.next_pos, self.next_pos + k.size(1), device=k.device
)
v_pos = torch.arange(
self.next_pos, self.next_pos + v.size(1), device=v.device
)
if self.rolling:
k_pos = k_pos % self.shape[1]
v_pos = v_pos % self.shape[1]
self.k_cache[: k.size(0), k_pos] = k
self.v_cache[: v.size(0), v_pos] = v
self.next_pos += k.size(1)
else:
assert pos.size(0) == k.size(1)
assert max_pos is not None, (
"Need to pass in `pos.max()` explicitly. "
"Doing `pos.max()` creates massive overhead."
)
if self.rolling:
pos = pos % self.shape[1]
self.k_cache[: k.size(0), pos] = k
self.v_cache[: v.size(0), pos] = v
# Update next_pos using the max entry.
# Note: `self.next_pos = pos.max() + 1` could have worked, but it
# causes the shape to be dynamic and creates a massive overhead.
self.next_pos = max_pos + 1
return (
self.k_cache[: k.size(0), start_pos : self.next_pos],
self.v_cache[: v.size(0), start_pos : self.next_pos],
)

return self._get_tensor(
self.k_cache, start_pos, self.next_pos
), self._get_tensor(self.v_cache, start_pos, self.next_pos)
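
A note on the rolling cache added above (not part of this diff): the update amounts to a ring buffer over the position axis. Writes wrap with `pos % max_size`, and `_get_tensor` stitches the two halves back into chronological order once `next_pos` has run past the end. A minimal standalone sketch of the same indexing idea, with shapes simplified to (seq, dim) and names that are illustrative only:

import torch

MAX_SIZE = 4                          # the cache holds 4 positions
cache = torch.zeros(MAX_SIZE, 2)      # simplified (seq, dim) buffer
next_pos = 0

def write(x: torch.Tensor) -> None:
    # Write an (n, dim) chunk at wrapped positions, mirroring the rolling update().
    global next_pos
    pos = torch.arange(next_pos, next_pos + x.size(0)) % MAX_SIZE
    cache[pos] = x
    next_pos += x.size(0)

def read() -> torch.Tensor:
    # Return the cached window oldest-first, mirroring _get_tensor().
    if next_pos > MAX_SIZE:
        cut = next_pos % MAX_SIZE
        return torch.cat([cache[cut:], cache[:cut]], dim=0)
    return cache[:next_pos]

write(torch.arange(6.0).reshape(3, 2))        # positions 0-2
write(torch.arange(6.0, 12.0).reshape(3, 2))  # positions 3-5 overwrite 0-1
print(read())                                 # rows for positions 2-5, oldest first
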
69 changes: 49 additions & 20 deletions aria/run.py
Original file line number Diff line number Diff line change
@@ -4,8 +4,11 @@
import os
import re
import sys
import tqdm
import pathlib
import warnings
from queue import Queue
from threading import Thread


# TODO: Implement a way of inferring the tokenizer name automatically
@@ -52,6 +55,10 @@ def _parse_sample_args():
argp.add_argument(
"-sup", action="store_true", help="suppress fluidsynth", default=False
)
argp.add_argument("-live", action="store_true", help="live playing mode")
argp.add_argument(
"-roll", type=int, help="inference on a rolling window", default=0
)

return argp.parse_args(sys.argv[2:])

@@ -123,7 +130,7 @@ def sample(args):
from aria.tokenizer import RelTokenizer, AbsTokenizer
from aria.sample import greedy_sample
from aria.data.midi import MidiDict
from aria.utils import midi_to_audio
from aria.utils import midi_to_audio, _play, _ensure_fluidsynth

if not cuda_is_available():
print("CUDA device is not available. Using CPU instead.")
@@ -226,28 +233,50 @@ def _quantize(module, key, input_shape):
prompts = [prompt_seq for _ in range(num_variations)]

# Sample
results = greedy_sample(
model,
tokenizer,
prompts,
device=device,
force_end=force_end,
max_new_tokens=max_new_tokens,
cfg_gamma=args.cfg,
temperature=args.temp,
)
kwargs = {
"model": model,
"tokenizer": tokenizer,
"prompts": prompts,
"device": device,
"force_end": force_end,
"max_new_tokens": max_new_tokens,
"cfg_gamma": args.cfg,
"temperature": args.temp,
"rolling": args.roll,
}
if args.live:
_ensure_fluidsynth()
input_queue = Queue()

iterator = greedy_sample(
**kwargs,
stream_tokens=True,
verbose=True,
)
pbar = tqdm.tqdm(total=max_new_tokens)
player = Thread(
target=_play, args=(input_queue, args.tok == "rel", pbar)
)
player.start()

for token in iterator:
input_queue.put_nowait(tokenizer.decode(token)[0])
input_queue.put(None)
player.join()
else:
results = greedy_sample(**kwargs)

if os.path.isdir("samples") is False:
os.mkdir("samples")
if os.path.isdir("samples") is False:
os.mkdir("samples")

for idx, tokenized_seq in enumerate(results):
res_midi_dict = tokenizer.detokenize(tokenized_seq)
res_midi = res_midi_dict.to_midi()
res_midi.save(f"samples/res_{idx + 1}.mid")
if args.sup is False:
midi_to_audio(f"samples/:res_{idx + 1}.mid")
for idx, tokenized_seq in enumerate(results):
res_midi_dict = tokenizer.detokenize(tokenized_seq)
res_midi = res_midi_dict.to_midi()
res_midi.save(f"samples/res_{idx + 1}.mid")
if args.sup is False:
midi_to_audio(f"samples/res_{idx + 1}.mid")

print("Results saved to samples/")
print("Results saved to samples/")


def _parse_midi_dataset_args():
87 changes: 54 additions & 33 deletions aria/sample.py
@@ -9,7 +9,7 @@
import math
import torch

from typing import List
from typing import List, Iterator
from tqdm import tqdm

from aria.model import TransformerLM
@@ -18,6 +18,7 @@

# TODO: Add which instruments were detected in the prompt


def _get_cfg_coeff(cfg_gamma, cfg_mode, cur_pos, start_pos, total_len):
if cfg_mode is None:
return cfg_gamma
Expand Down Expand Up @@ -88,6 +89,12 @@ def _batch_encode(tokenizer, prompts: list[list]) -> torch.Tensor:
return torch.stack([tokenizer.encode(p) for p in prompts], dim=0)


def _process_output(tokens: torch.Tensor, use_cfg: bool) -> torch.Tensor:
if use_cfg:
tokens = tokens[: tokens.size(0) // 2]
return tokens.cpu().view(-1)


# Some good settings:
# temp=0.85, top_p=0.9, cfg_gamma=1.4

@@ -104,10 +111,13 @@ def greedy_sample(
neg_prompts: List[list] | None = None,
neg_prompt_len: int | None = None,
alpha: float | None = 0.4,
force_end=False,
force_end: bool = False,
temperature: float = 0.85,
top_p: float = 0.9,
):
rolling: int = 0,
stream_tokens: bool = False,
verbose: bool = True,
) -> Iterator[list]:
"""Performs greedy (top_p) autoregressive sampling on a batch of prompts.

Args:
@@ -133,9 +143,11 @@ def greedy_sample(
force_end (bool, optional): Whether to force the end of the prompt. Defaults to False.
temperature (float, optional): Sampling temperature. Defaults to 0.85.
top_p (float, optional): Parameter for top-p sampling. Defaults to 0.9.

rolling (int, optional): Size of the rolling cache window; 0 disables rolling. Defaults to 0 (disabled).
stream_tokens (bool, optional): Whether to stream tokens as a generator. Defaults to False.
verbose (bool, optional): Whether to print progress. Defaults to True.
Returns:
List[list]: The list of samples, decoded by the tokenizer.
Iterator[list]: An iterator of samples, decoded by the tokenizer.
"""
assert tokenizer.return_tensors is True, "tokenizer must return tensors."
device = device or torch.device("cuda")
@@ -172,7 +184,8 @@ def greedy_sample(
f"Using hyperparams: temp={temperature}, top_p={top_p}, gamma={cfg_gamma}, gen_len={max_new_tokens}"
)

total_len = prompt_len + max_new_tokens
total_len = prompt_len + max_new_tokens # total length of the sequence
window_len = total_len if not rolling else rolling # rolling window size
tokens = torch.full(
(len(padded_combined_prompts), total_len), pad_id, device=device
)
@@ -193,24 +206,33 @@ def greedy_sample(
start_pos = prompt_len

past_kv = model.get_cache(
max_batch_size=tokens.size(0), max_len=total_len, device=device
max_batch_size=tokens.size(0), max_len=window_len, device=device
)

next_token = tokens[:, :start_pos]

if stream_tokens:
# Yield the prompt tokens first
for i in range(start_pos):
yield _process_output(
next_token[:, i], use_cfg=cfg_gamma is not None
)
for cur_pos in (
pbar := tqdm(
range(start_pos, total_len),
total=total_len - start_pos,
leave=False,
disable=not verbose,
desc="Token generation progress",
)
):
if cur_pos == start_pos:
token = tokens[:, :start_pos]
if rolling and cfg_gamma is not None:
# Have to use a fixed attn_mask if CFG is used.
# Otherwise, once the rolling window is filled, both prompts become the same.
mask = attn_mask[:, : min(cur_pos, window_len)]
else:
token = tokens[:, cur_pos - 1 : cur_pos]

logits = model.forward(
token, attn_mask=attn_mask[:, :cur_pos], past_kv=past_kv
)
mask = attn_mask[:, max(0, cur_pos - window_len) : cur_pos]
logits = model.forward(next_token, attn_mask=mask, past_kv=past_kv)
logits = logits[:, -1, :]

if cfg_gamma is not None:
@@ -255,25 +277,24 @@ def greedy_sample(
if next_token[_idx] == tokenizer.tok_to_id[tokenizer.dim_tok]:
dim_tok_inserted[_idx] = True

tokens[:, cur_pos] = next_token

decoded = []
for idx, seq in enumerate(tokens.tolist()):
if cfg_gamma is not None and 2 * idx >= tokens.size(0):
break
# Cut to eos tok if any
try:
seq = seq[: seq.index(eos_id)]
except ValueError:
pass
decoded.append(tokenizer.decode(seq))

for idx, seq in enumerate(decoded):
if tokenizer.eos_tok in seq:
eos_idx = seq.index(tokenizer.eos_tok)
decoded[idx] = seq[:eos_idx]

return decoded
if stream_tokens:
# Yield tokens as they are generated
yield _process_output(next_token, use_cfg=cfg_gamma is not None)
else:
# Update tokens
tokens[:, cur_pos] = next_token
next_token = next_token.unsqueeze(1) # (bsz) -> (bsz, 1)

if not stream_tokens:
for idx, seq in enumerate(tokens.tolist()):
if cfg_gamma is not None and 2 * idx >= tokens.size(0):
break
# Cut to eos tok if any
try:
end = seq.index(eos_id)
yield tokenizer.decode(seq[:end])
except ValueError:
yield tokenizer.decode(seq)
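
Since greedy_sample is now a generator in both modes, callers have to iterate it (or wrap it in list(...)) even when stream_tokens is False, as run.py does above. A hypothetical wrapper (not part of the PR) showing the non-streaming call pattern with the keyword arguments used in run.py; model, tokenizer, and prompt setup are elided, and the hyperparameter values simply echo the comment near the top of this file:

import torch
from aria.sample import greedy_sample

def sample_to_list(model, tokenizer, prompts, max_new_tokens: int = 512) -> list:
    # Collect every decoded sequence yielded by the generator.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return list(
        greedy_sample(
            model=model,
            tokenizer=tokenizer,
            prompts=prompts,
            device=device,
            force_end=False,
            max_new_tokens=max_new_tokens,
            cfg_gamma=1.4,   # temp=0.85, top_p=0.9, cfg_gamma=1.4 per the comment above
            temperature=0.85,
            top_p=0.9,
            rolling=0,       # 0 keeps the non-rolling (fixed-length) cache
        )
    )
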


def sample_top_p(probs, p):
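
The body of sample_top_p is collapsed in this view. For context, a standard nucleus (top-p) sampler with this signature looks roughly like the sketch below; this is an assumption for illustration, not necessarily the repository's exact implementation:

import torch

def sample_top_p_sketch(probs: torch.Tensor, p: float) -> torch.Tensor:
    # Sample from the smallest set of tokens whose cumulative probability exceeds p.
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Zero out tokens outside the nucleus (the top token is always kept).
    sorted_probs[cumulative - sorted_probs > p] = 0.0
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
    next_sorted = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_idx, -1, next_sorted)

probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])  # one distribution over 4 tokens
print(sample_top_p_sketch(probs, p=0.9))        # index drawn from the 0.9 nucleus
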