From 754c27cff782a92fab0884c18a96a00217e4d82f Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Sun, 17 Dec 2023 07:44:09 +0000 Subject: [PATCH 1/9] implement a live playing mode in sample, using Max's code --- aria/run.py | 54 +++++++++++++++++++++------------- aria/sample.py | 73 +++++++++++++++++++++++++++++----------------- aria/utils.py | 79 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 159 insertions(+), 47 deletions(-) diff --git a/aria/run.py b/aria/run.py index 9c9c6f6..e3c3b6b 100644 --- a/aria/run.py +++ b/aria/run.py @@ -6,6 +6,8 @@ import sys import pathlib import warnings +from queue import Queue +from threading import Thread # TODO: Implement a way of inferring the tokenizer name automatically @@ -52,6 +54,7 @@ def _parse_sample_args(): argp.add_argument( "-sup", action="store_true", help="suppress fluidsynth", default=False ) + argp.add_argument("-live", action="store_true", help="live playing mode") return argp.parse_args(sys.argv[2:]) @@ -125,7 +128,7 @@ def sample(args): from aria.tokenizer import RelTokenizer, AbsTokenizer from aria.sample import greedy_sample from aria.data.midi import MidiDict - from aria.utils import midi_to_audio + from aria.utils import midi_to_audio, _play if not cuda_is_available(): print("CUDA device is not available. Using CPU instead.") @@ -227,28 +230,39 @@ def _quantize(module, key, input_shape): prompts = [prompt_seq for _ in range(num_variations)] # Sample - results = greedy_sample( - model, - tokenizer, - prompts, - device=device, - force_end=force_end, - max_new_tokens=max_new_tokens, - cfg_gamma=args.cfg, - temperature=args.temp, - ) + kwargs = { + "model": model, + "tokenizer": tokenizer, + "prompts": prompts, + "device": device, + "force_end": force_end, + "max_new_tokens": max_new_tokens, + "cfg_gamma": args.cfg, + "temperature": args.temp, + } + if args.live: + input_queue = Queue() + output_queue = Queue() + player = Thread(target=_play, args=(input_queue, output_queue)) + player.start() + for token in greedy_sample(**kwargs, stream_tokens=True, verbose=True): + input_queue.put_nowait(tokenizer.decode(token)[0]) + input_queue.put(None) + player.join() + else: + results = greedy_sample(**kwargs) - if os.path.isdir("samples") is False: - os.mkdir("samples") + if os.path.isdir("samples") is False: + os.mkdir("samples") - for idx, tokenized_seq in enumerate(results): - res_midi_dict = tokenizer.detokenize(tokenized_seq) - res_midi = res_midi_dict.to_midi() - res_midi.save(f"samples/res_{idx + 1}.mid") - if args.sup is False: - midi_to_audio(f"samples/:res_{idx + 1}.mid") + for idx, tokenized_seq in enumerate(results): + res_midi_dict = tokenizer.detokenize(tokenized_seq) + res_midi = res_midi_dict.to_midi() + res_midi.save(f"samples/res_{idx + 1}.mid") + if args.sup is False: + midi_to_audio(f"samples/:res_{idx + 1}.mid") - print("Results saved to samples/") + print("Results saved to samples/") def _parse_midi_dataset_args(): diff --git a/aria/sample.py b/aria/sample.py index 8f88c46..f495699 100644 --- a/aria/sample.py +++ b/aria/sample.py @@ -92,6 +92,12 @@ def _batch_encode(tokenizer, prompts: list[list]) -> torch.Tensor: return torch.stack([tokenizer.encode(p) for p in prompts], dim=0) +def _process_output(tokens: torch.Tensor, use_cfg: bool) -> torch.Tensor: + if use_cfg: + tokens = tokens[: tokens.size(0) // 2] + return tokens.cpu().view(-1) + + # Some good settings: # temp=0.85, top_p=0.9, cfg_gamma=1.4 @@ -108,9 +114,11 @@ def greedy_sample( neg_prompts: List[list] | None = None, neg_prompt_len: int | None = 
None, alpha: float | None = 0.4, - force_end=False, + force_end: bool = False, temperature: float = 0.85, top_p: float = 0.9, + stream_tokens: bool = False, + verbose: bool = False, ): """Performs greedy (top_p) autoregressive sampling on a batch of prompts. @@ -137,7 +145,8 @@ def greedy_sample( force_end (bool, optional): Whether to force the end of the prompt. Defaults to False. temperature (float, optional): Sampling temperature. Defaults to 0.75. top_p (float, optional): Parameter for top-p sampling. Defaults to 0.95. - + stream_tokens (bool, optional): Whether to stream tokens as a generator. Defaults to False. + verbose (bool, optional): Whether to print progress. Defaults to False. Returns: List[list]: The list of samples, decoded by the tokenizer. """ @@ -200,20 +209,23 @@ def greedy_sample( max_batch_size=tokens.size(0), max_len=total_len, device=device ) + next_token = tokens[:, :start_pos] + if stream_tokens: + for i in range(start_pos): + yield _process_output( + next_token[:, i], use_cfg=cfg_gamma is not None + ) + for cur_pos in ( pbar := tqdm( range(start_pos, total_len), total=total_len - start_pos, leave=False, + disable=not verbose, ) ): - if cur_pos == start_pos: - token = tokens[:, :start_pos] - else: - token = tokens[:, cur_pos - 1 : cur_pos] - logits = model.forward( - token, attn_mask=attn_mask[:, :cur_pos], past_kv=past_kv + next_token, attn_mask=attn_mask[:, :cur_pos], past_kv=past_kv ) logits = logits[:, -1, :] @@ -259,25 +271,32 @@ def greedy_sample( if next_token[_idx] == tokenizer.tok_to_id[tokenizer.dim_tok]: dim_tok_inserted[_idx] = True - tokens[:, cur_pos] = next_token - - decoded = [] - for idx, seq in enumerate(tokens.tolist()): - if cfg_gamma is not None and 2 * idx >= tokens.size(0): - break - # Cut to eos tok if any - try: - seq = seq[: seq.index(eos_id)] - except ValueError: - pass - decoded.append(tokenizer.decode(seq)) - - for idx, seq in enumerate(decoded): - if tokenizer.eos_tok in seq: - eos_idx = seq.index(tokenizer.eos_tok) - decoded[idx] = seq[:eos_idx] - - return decoded + if stream_tokens: + # Yield tokens as they are generated + yield _process_output(next_token, use_cfg=cfg_gamma is not None) + else: + # Update tokens + tokens[:, cur_pos] = next_token + next_token = next_token.unsqueeze(1) # (bsz) -> (bsz, 1) + + if not stream_tokens: + decoded = [] + for idx, seq in enumerate(tokens.tolist()): + if cfg_gamma is not None and 2 * idx >= tokens.size(0): + break + # Cut to eos tok if any + try: + seq = seq[: seq.index(eos_id)] + except ValueError: + pass + decoded.append(tokenizer.decode(seq)) + + for idx, seq in enumerate(decoded): + if tokenizer.eos_tok in seq: + eos_idx = seq.index(tokenizer.eos_tok) + decoded[idx] = seq[:eos_idx] + + return decoded def sample_top_p(probs, p): diff --git a/aria/utils.py b/aria/utils.py index acf1ef6..e8c6155 100644 --- a/aria/utils.py +++ b/aria/utils.py @@ -2,7 +2,10 @@ import os import subprocess +import time + import requests +from multiprocessing import Queue from pydub import AudioSegment @@ -59,3 +62,79 @@ def midi_to_audio(mid_path: str, soundfont_path: str | None = None): print(e) print(f"Saved files: \n{wav_path}\n{mp3_path}") + + +def _get_soundfont(path: str) -> bool: + DOWNLOAD_URL = "https://www.dropbox.com/scl/fi/t8gou8stesm42sc559nzu/DoreMarkYamahaS6-v1.6.sf2?rlkey=28ecl63kkjjmwxrkd6hnzsq8f&dl=1" + # download soundfont if it's not already there + if not os.path.isfile(path): + if not os.path.isdir("fluidsynth"): + os.mkdir("fluidsynth") + print("Downloading soundfont ...") + res = 
requests.get(url=DOWNLOAD_URL) + if res.status_code == 200: + with open(path, "wb") as file_handle: + file_handle.write(res.content) + print("Download complete") + else: + print(f"Failed to download soundfont: RESPONSE {res.status_code}") + return False + return True + + +def _play(input_queue: Queue, output_queue: Queue): + """ + Run in a separate process and receive tokens and play them with fluidsynth + Credits to @maxreciprocate + """ + SOUNDFONT_PATH = "fluidsynth/DoreMarkYamahaS6-v1.6.sf2" + + if not _get_soundfont(SOUNDFONT_PATH): + return + + import fluidsynth # lazy import + import platform + + fs = fluidsynth.Synth() + if platform.system() == "Linux": + fs.start(driver="pulseaudio") + else: + fs.start() + + sfid = fs.sfload(SOUNDFONT_PATH) + + fs.program_select(0, sfid, 0, 0) + + output_queue.put_nowait(True) + + finish = False + current_note = None + open_notes = {} + while True: + if finish and input_queue.empty(): + output_queue.put_nowait(finish) + finish = False + elif not input_queue.empty(): + m = input_queue.get() + print(m) + if m is None: # exit + break + elif m == "": + finish = True + elif m[0] == "piano": + fs.noteon(0, m[1], m[2]) + current_note = m[1] + elif m[0] == "dur": + if current_note is not None: + open_notes[current_note] = m[1] + current_note = None + elif m[0] == "wait": + time.sleep(m[1] / 1000) + + for note in list(open_notes.keys()): + open_notes[note] -= m[1] + if open_notes[note] <= 0: + del open_notes[note] + fs.noteoff(0, note) + else: + time.sleep(0.1) From 585e6786ea12b3b03c27052050bed6f458570f0f Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Sun, 17 Dec 2023 07:47:55 +0000 Subject: [PATCH 2/9] add warning; adjust verbose --- aria/run.py | 2 +- aria/sample.py | 2 +- aria/utils.py | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/aria/run.py b/aria/run.py index e3c3b6b..84d7a9c 100644 --- a/aria/run.py +++ b/aria/run.py @@ -245,7 +245,7 @@ def _quantize(module, key, input_shape): output_queue = Queue() player = Thread(target=_play, args=(input_queue, output_queue)) player.start() - for token in greedy_sample(**kwargs, stream_tokens=True, verbose=True): + for token in greedy_sample(**kwargs, stream_tokens=True, verbose=False): input_queue.put_nowait(tokenizer.decode(token)[0]) input_queue.put(None) player.join() diff --git a/aria/sample.py b/aria/sample.py index f495699..dd72b5e 100644 --- a/aria/sample.py +++ b/aria/sample.py @@ -118,7 +118,7 @@ def greedy_sample( temperature: float = 0.85, top_p: float = 0.9, stream_tokens: bool = False, - verbose: bool = False, + verbose: bool = True, ): """Performs greedy (top_p) autoregressive sampling on a batch of prompts. 
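For readers following the series, the live mode added in PATCH 1 reduces to a producer/consumer pipeline: `greedy_sample(..., stream_tokens=True)` yields tokens one at a time, `run.py` decodes each one and pushes it onto a `Queue`, and the `_play` thread consumes messages such as `("piano", pitch, velocity)`, `("dur", ms)` and `("wait", ms)` until it sees the `None` sentinel. The following minimal, self-contained sketch illustrates that wiring only; the generator stub and the `print` standing in for fluidsynth playback are hypothetical and not part of the patches themselves.

```python
from queue import Queue
from threading import Thread
import time


def fake_token_stream():
    # Hypothetical stand-in for greedy_sample(..., stream_tokens=True):
    # yields one message per step, the way run.py receives decoded tokens.
    for msg in [("piano", 60, 80), ("dur", 200), ("wait", 250)]:
        time.sleep(0.05)  # simulated model latency
        yield msg


def player(queue: Queue):
    # Consumer loop mirroring _play: drain the queue until the None sentinel.
    while True:
        msg = queue.get()
        if msg is None:
            break
        print("play", msg)  # _play calls fluidsynth noteon/noteoff/sleep here


q = Queue()
thread = Thread(target=player, args=(q,))
thread.start()

for msg in fake_token_stream():
    q.put_nowait(msg)  # run.py pushes tokenizer.decode(token)[0] the same way
q.put(None)  # end-of-stream sentinel, as in run.py's live branch
thread.join()
```

The non-blocking `put_nowait` on the producer side keeps token generation from waiting on playback; when the queue instead runs dry, the check added in this patch prints the one-time warning "token generation is falling behind".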
diff --git a/aria/utils.py b/aria/utils.py index e8c6155..9af5e7a 100644 --- a/aria/utils.py +++ b/aria/utils.py @@ -88,6 +88,7 @@ def _play(input_queue: Queue, output_queue: Queue): Credits to @maxreciprocate """ SOUNDFONT_PATH = "fluidsynth/DoreMarkYamahaS6-v1.6.sf2" + _catch_up_warned = False if not _get_soundfont(SOUNDFONT_PATH): return @@ -137,4 +138,7 @@ def _play(input_queue: Queue, output_queue: Queue): del open_notes[note] fs.noteoff(0, note) else: + if not _catch_up_warned: + print("Warning: token generation is falling behind") + _catch_up_warned = True time.sleep(0.1) From b965357935b560fd524007979418fe3f1ab9cc51 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Sun, 17 Dec 2023 18:29:54 +0000 Subject: [PATCH 3/9] add support for abs tokenizer for live play --- aria/model/cache.py | 54 +++++++++++++++---- aria/run.py | 20 +++++-- aria/sample.py | 21 +++++--- aria/utils.py | 127 ++++++++++++++++++++++++++++++++++---------- 4 files changed, 175 insertions(+), 47 deletions(-) diff --git a/aria/model/cache.py b/aria/model/cache.py index e60b399..574020b 100644 --- a/aria/model/cache.py +++ b/aria/model/cache.py @@ -5,9 +5,26 @@ class KVCache(torch.nn.Module): def __init__( - self, max_batch_size, n_head, d_head, dtype=torch.float16, max_size=8192 + self, + max_batch_size, + n_head, + d_head, + dtype=torch.float16, + max_size=8192, + rolling=True, ): + """ + Cache for key-value pairs used in self-attention. + Args: + max_batch_size: the maximum batch size + n_head: the number of heads + d_head: the dimension of each head + dtype: the dtype of the cache + max_size: the maximum number of positions to cache + rolling: whether to roll when it is full + """ super().__init__() + self.rolling = rolling self.shape = (max_batch_size, max_size, n_head, d_head) self.register_buffer( "k_cache", torch.empty(self.shape, dtype=dtype), persistent=False @@ -17,6 +34,18 @@ def __init__( ) self.next_pos = 0 + def _get_tensor(self, cache, start_pos, next_pos): + if self.rolling and next_pos > self.shape[1]: + return torch.cat( + [ + cache[:, next_pos % self.shape[1] :], + cache[:, : next_pos % self.shape[1]], + ], + dim=1, + ) + else: + return cache[:, start_pos:next_pos] + def update( self, k, @@ -42,12 +71,13 @@ def update( due to dynamic shape. """ if pos is None: - self.k_cache[ - : k.size(0), self.next_pos : self.next_pos + k.size(1) - ] = k - self.v_cache[ - : v.size(0), self.next_pos : self.next_pos + v.size(1) - ] = v + k_pos = torch.arange(self.next_pos, self.next_pos + k.size(1)) + v_pos = torch.arange(self.next_pos, self.next_pos + v.size(1)) + if self.rolling: + k_pos = k_pos % self.shape[1] + v_pos = v_pos % self.shape[1] + self.k_cache[: k.size(0), k_pos] = k + self.v_cache[: v.size(0), v_pos] = v self.next_pos += k.size(1) else: assert pos.size(0) == k.size(1) @@ -55,13 +85,15 @@ def update( "Need to pass in `pos.max()` explicitly. " "Doing `pos.max()` creates massive overhead." ) + if self.rolling: + pos = pos % self.shape[1] self.k_cache[: k.size(0), pos] = k self.v_cache[: v.size(0), pos] = v # Update next_pos using the max entry. # Note: `self.next_pos = pos.max() + 1` could have worked, but it # causes the shape to be dynamic and creates a massive overhead. 
self.next_pos = max_pos + 1 - return ( - self.k_cache[: k.size(0), start_pos : self.next_pos], - self.v_cache[: v.size(0), start_pos : self.next_pos], - ) + + return self._get_tensor( + self.k_cache, start_pos, self.next_pos + ), self._get_tensor(self.v_cache, start_pos, self.next_pos) diff --git a/aria/run.py b/aria/run.py index 84d7a9c..b217eba 100644 --- a/aria/run.py +++ b/aria/run.py @@ -4,6 +4,7 @@ import os import re import sys +import tqdm import pathlib import warnings from queue import Queue @@ -55,6 +56,9 @@ def _parse_sample_args(): "-sup", action="store_true", help="suppress fluidsynth", default=False ) argp.add_argument("-live", action="store_true", help="live playing mode") + argp.add_argument( + "-roll", type=int, help="inference on a rolling window", default=0 + ) return argp.parse_args(sys.argv[2:]) @@ -239,13 +243,23 @@ def _quantize(module, key, input_shape): "max_new_tokens": max_new_tokens, "cfg_gamma": args.cfg, "temperature": args.temp, + "rolling": args.roll, } if args.live: input_queue = Queue() - output_queue = Queue() - player = Thread(target=_play, args=(input_queue, output_queue)) + + iterator = greedy_sample( + **kwargs, + stream_tokens=True, + verbose=True, + ) + pbar = tqdm.tqdm(total=max_new_tokens) + player = Thread( + target=_play, args=(input_queue, args.tok == "rel", pbar) + ) player.start() - for token in greedy_sample(**kwargs, stream_tokens=True, verbose=False): + + for token in iterator: input_queue.put_nowait(tokenizer.decode(token)[0]) input_queue.put(None) player.join() diff --git a/aria/sample.py b/aria/sample.py index dd72b5e..16b90e1 100644 --- a/aria/sample.py +++ b/aria/sample.py @@ -117,6 +117,7 @@ def greedy_sample( force_end: bool = False, temperature: float = 0.85, top_p: float = 0.9, + rolling: int = 0, stream_tokens: bool = False, verbose: bool = True, ): @@ -145,6 +146,7 @@ def greedy_sample( force_end (bool, optional): Whether to force the end of the prompt. Defaults to False. temperature (float, optional): Sampling temperature. Defaults to 0.75. top_p (float, optional): Parameter for top-p sampling. Defaults to 0.95. + rolling (int, optional): Whether to roll the cache. Defaults to 0 (disabled). stream_tokens (bool, optional): Whether to stream tokens as a generator. Defaults to False. verbose (bool, optional): Whether to print progress. Defaults to False. Returns: @@ -185,7 +187,8 @@ def greedy_sample( f"Using hyperparams: temp={temperature}, top_p={top_p}, gamma={cfg_gamma}, gen_len={max_new_tokens}" ) - total_len = prompt_len + max_new_tokens + total_len = prompt_len + max_new_tokens # total length of the sequence + window_len = total_len if not rolling else rolling # rolling window size tokens = torch.full( (len(padded_combined_prompts), total_len), pad_id, device=device ) @@ -206,27 +209,33 @@ def greedy_sample( start_pos = prompt_len past_kv = model.get_cache( - max_batch_size=tokens.size(0), max_len=total_len, device=device + max_batch_size=tokens.size(0), max_len=window_len, device=device ) next_token = tokens[:, :start_pos] + if stream_tokens: + # Yield the prompt tokens first for i in range(start_pos): yield _process_output( next_token[:, i], use_cfg=cfg_gamma is not None ) - for cur_pos in ( pbar := tqdm( range(start_pos, total_len), total=total_len - start_pos, leave=False, disable=not verbose, + desc="Token generation progress", ) ): - logits = model.forward( - next_token, attn_mask=attn_mask[:, :cur_pos], past_kv=past_kv - ) + if rolling and cfg_gamma is not None: + # Have to use a fixed attn_mask if CFG is used. 
+ # Otherwise, when the rolling window is filled, the both prompts become the same. + mask = attn_mask[:, : min(cur_pos, window_len)] + else: + mask = attn_mask[:, max(0, cur_pos - window_len) : cur_pos] + logits = model.forward(next_token, attn_mask=mask, past_kv=past_kv) logits = logits[:, -1, :] if cfg_gamma is not None: diff --git a/aria/utils.py b/aria/utils.py index 9af5e7a..f12a2a0 100644 --- a/aria/utils.py +++ b/aria/utils.py @@ -3,6 +3,7 @@ import os import subprocess import time +import tqdm import requests from multiprocessing import Queue @@ -82,10 +83,76 @@ def _get_soundfont(path: str) -> bool: return True -def _play(input_queue: Queue, output_queue: Queue): +def _handle_absolute( + m, fs, current_note, last_time, open_notes, abs_time_step_ms=5000, **kwargs +): + if m[0] == "piano": + current_note = m + elif m[0] == "dur": + if current_note is not None: + open_notes[current_note] = m[1] + current_note = None + elif m[0] == "onset": + if m[1] < last_time: + raise ValueError( + "Sequence corrupted! Onset time is greater than current time" + ) + time.sleep( + (m[1] - last_time) / 1000 + ) # First wait for the time to update + if current_note is None or current_note[0] != "piano": + raise ValueError("Sequence corrupted! No current note to play") + fs.noteon( + 0, current_note[1], current_note[2] + ) # Next play the previously received note + + for note in list(open_notes.keys()): + open_notes[note] -= m[1] - last_time + if open_notes[note] <= 0: + del open_notes[note] + fs.noteoff(0, note[1]) + + last_time = m[1] + elif m == "": + time.sleep(max(0, abs_time_step_ms - last_time) / 1000) + last_time = 0 + + return current_note, last_time + + +def _handle_relative(m, fs, current_note, open_notes, **kwargs): + if m[0] == "piano": + fs.noteon(0, m[1], m[2]) + current_note = m[1] + elif m[0] == "dur": + if current_note is not None: + open_notes[current_note] = m[1] + current_note = None + elif m[0] == "wait": + time.sleep(m[1] / 1000) + + for note in list(open_notes.keys()): + open_notes[note] -= m[1] + if open_notes[note] <= 0: + del open_notes[note] + fs.noteoff(0, note) + + return current_note + + +def _play( + input_queue: Queue, + is_relative: bool = False, + pbar=None, + abs_time_step_ms=5000, +): """ Run in a separate process and receive tokens and play them with fluidsynth Credits to @maxreciprocate + Args: + input_queue: queue to receive tokens from + is_relative: whether the tokens are relative or absolute + pbar: tqdm progress bar """ SOUNDFONT_PATH = "fluidsynth/DoreMarkYamahaS6-v1.6.sf2" _catch_up_warned = False @@ -106,37 +173,43 @@ def _play(input_queue: Queue, output_queue: Queue): fs.program_select(0, sfid, 0, 0) - output_queue.put_nowait(True) - - finish = False current_note = None + # Previous onset time (effective only when is_relative is False) + last_time = 0 + # Currently open notes open_notes = {} + # The rolling log of the past 5 tokens + rolling_cache, rolling_index = [""] * 5, 0 while True: - if finish and input_queue.empty(): - output_queue.put_nowait(finish) - finish = False - elif not input_queue.empty(): + if not input_queue.empty(): m = input_queue.get() - print(m) - if m is None: # exit + if pbar is not None: + rolling_cache[rolling_index] = ( + m[0] if isinstance(m, tuple) else m + ) + pbar.update(1) + pbar.set_description( + "Past tokens:" + + ", ".join( + rolling_cache[rolling_index:] + + rolling_cache[:rolling_index] + ) + ) + + rolling_index = (rolling_index + 1) % len(rolling_cache) + if m is None or m == "": # exit break - elif m == "": - finish 
= True - elif m[0] == "piano": - fs.noteon(0, m[1], m[2]) - current_note = m[1] - elif m[0] == "dur": - if current_note is not None: - open_notes[current_note] = m[1] - current_note = None - elif m[0] == "wait": - time.sleep(m[1] / 1000) - - for note in list(open_notes.keys()): - open_notes[note] -= m[1] - if open_notes[note] <= 0: - del open_notes[note] - fs.noteoff(0, note) + elif is_relative: + current_note = _handle_relative(m, fs, current_note, open_notes) + else: + current_note, last_time = _handle_absolute( + m, + fs, + current_note, + last_time, + open_notes, + abs_time_step_ms=abs_time_step_ms, + ) else: if not _catch_up_warned: print("Warning: token generation is falling behind") From 98774ee2525b1473582120d0976f9e4238e95bae Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Sun, 17 Dec 2023 18:42:59 +0000 Subject: [PATCH 4/9] remove gradent checkpointing for sampling --- aria/run.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aria/run.py b/aria/run.py index b217eba..d3a2918 100644 --- a/aria/run.py +++ b/aria/run.py @@ -161,6 +161,7 @@ def sample(args): model_config = ModelConfig(**load_model_config(model_name)) model_config.set_vocab_size(tokenizer.vocab_size) + model_config.grad_checkpoint = False model = TransformerLM(model_config).to(device) if args.trunc + args.l > model_config.max_seq_len: From 8eaf8b485af81f590789b202600fbb0790ba324b Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Sun, 17 Dec 2023 20:01:56 +0000 Subject: [PATCH 5/9] fix bug --- aria/run.py | 2 +- aria/sample.py | 20 ++++++-------------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/aria/run.py b/aria/run.py index d3a2918..6249e03 100644 --- a/aria/run.py +++ b/aria/run.py @@ -275,7 +275,7 @@ def _quantize(module, key, input_shape): res_midi = res_midi_dict.to_midi() res_midi.save(f"samples/res_{idx + 1}.mid") if args.sup is False: - midi_to_audio(f"samples/:res_{idx + 1}.mid") + midi_to_audio(f"samples/res_{idx + 1}.mid") print("Results saved to samples/") diff --git a/aria/sample.py b/aria/sample.py index 16b90e1..f81c677 100644 --- a/aria/sample.py +++ b/aria/sample.py @@ -9,7 +9,7 @@ import math import torch -from typing import List +from typing import List, Iterator from tqdm import tqdm from aria.model import TransformerLM @@ -120,7 +120,7 @@ def greedy_sample( rolling: int = 0, stream_tokens: bool = False, verbose: bool = True, -): +) -> Iterator[list]: """Performs greedy (top_p) autoregressive sampling on a batch of prompts. Args: @@ -150,7 +150,7 @@ def greedy_sample( stream_tokens (bool, optional): Whether to stream tokens as a generator. Defaults to False. verbose (bool, optional): Whether to print progress. Defaults to False. Returns: - List[list]: The list of samples, decoded by the tokenizer. + Iterator[list]: An iterator of samples, decoded by the tokenizer. """ assert tokenizer.return_tensors is True, "tokenizer must return tensors." 
device = device or torch.device("cuda") @@ -289,23 +289,15 @@ def greedy_sample( next_token = next_token.unsqueeze(1) # (bsz) -> (bsz, 1) if not stream_tokens: - decoded = [] for idx, seq in enumerate(tokens.tolist()): if cfg_gamma is not None and 2 * idx >= tokens.size(0): break # Cut to eos tok if any try: - seq = seq[: seq.index(eos_id)] + end = seq.index(eos_id) + yield tokenizer.decode(seq[:end]) except ValueError: - pass - decoded.append(tokenizer.decode(seq)) - - for idx, seq in enumerate(decoded): - if tokenizer.eos_tok in seq: - eos_idx = seq.index(tokenizer.eos_tok) - decoded[idx] = seq[:eos_idx] - - return decoded + yield tokenizer.decode(seq) def sample_top_p(probs, p): From bebfa106c6ebc8c82d49bd991894c58607ed31a6 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Sun, 17 Dec 2023 20:46:51 +0000 Subject: [PATCH 6/9] minor --- aria/model/cache.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/aria/model/cache.py b/aria/model/cache.py index 574020b..f2a928d 100644 --- a/aria/model/cache.py +++ b/aria/model/cache.py @@ -48,8 +48,8 @@ def _get_tensor(self, cache, start_pos, next_pos): def update( self, - k, - v, + k: torch.Tensor, + v: torch.Tensor, pos: Optional[torch.Tensor] = None, start_pos: int = 0, max_pos: Optional[int] = None, @@ -71,8 +71,12 @@ def update( due to dynamic shape. """ if pos is None: - k_pos = torch.arange(self.next_pos, self.next_pos + k.size(1)) - v_pos = torch.arange(self.next_pos, self.next_pos + v.size(1)) + k_pos = torch.arange( + self.next_pos, self.next_pos + k.size(1), device=k.device + ) + v_pos = torch.arange( + self.next_pos, self.next_pos + v.size(1), device=v.device + ) if self.rolling: k_pos = k_pos % self.shape[1] v_pos = v_pos % self.shape[1] From f5de61fbadfbe0e9486f22cf06da056f6e5957d9 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Sat, 23 Dec 2023 05:41:18 +0000 Subject: [PATCH 7/9] fix test; style --- aria/data/datasets.py | 4 +--- aria/sample.py | 1 + tests/test_models.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/aria/data/datasets.py b/aria/data/datasets.py index b9da67c..cc208cd 100644 --- a/aria/data/datasets.py +++ b/aria/data/datasets.py @@ -832,8 +832,6 @@ def _build(_midi_dataset): _build(_midi_dataset=midi_dataset) - logger.info( - f"Finished building, saved Finetuning to {save_path}" - ) + logger.info(f"Finished building, saved Finetuning to {save_path}") return cls(file_path=save_path, tokenizer=tokenizer) diff --git a/aria/sample.py b/aria/sample.py index 76a3053..78d4f9d 100644 --- a/aria/sample.py +++ b/aria/sample.py @@ -18,6 +18,7 @@ # TODO: Add which instruments were detected in the prompt + def _get_cfg_coeff(cfg_gamma, cfg_mode, cur_pos, start_pos, total_len): if cfg_mode is None: return cfg_gamma diff --git a/tests/test_models.py b/tests/test_models.py index 1e2ed98..8f89994 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -92,7 +92,7 @@ def test_generation(self): device=torch.device("cpu"), max_new_tokens=50, ) - assert [u == v for u, v in zip(out, out2[1:])] + assert [u == v for u, v in zip(out, list(out2)[1:])] if __name__ == "__main__": From 368f2568da6b6895edf07401e719c70a7a62d41d Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Sun, 24 Dec 2023 09:29:08 -0500 Subject: [PATCH 8/9] improve error message --- aria/run.py | 3 ++- aria/utils.py | 17 +++++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/aria/run.py b/aria/run.py index e7d673c..4d4da1b 100644 --- a/aria/run.py +++ b/aria/run.py @@ -130,7 +130,7 @@ def 
sample(args): from aria.tokenizer import RelTokenizer, AbsTokenizer from aria.sample import greedy_sample from aria.data.midi import MidiDict - from aria.utils import midi_to_audio, _play + from aria.utils import midi_to_audio, _play, _ensure_fluidsynth if not cuda_is_available(): print("CUDA device is not available. Using CPU instead.") @@ -245,6 +245,7 @@ def _quantize(module, key, input_shape): "rolling": args.roll, } if args.live: + _ensure_fluidsynth() input_queue = Queue() iterator = greedy_sample( diff --git a/aria/utils.py b/aria/utils.py index f12a2a0..665f1b9 100644 --- a/aria/utils.py +++ b/aria/utils.py @@ -1,6 +1,7 @@ """Contains miscellaneous utilities""" import os +import platform import subprocess import time import tqdm @@ -161,9 +162,8 @@ def _play( return import fluidsynth # lazy import - import platform - fs = fluidsynth.Synth() + if platform.system() == "Linux": fs.start(driver="pulseaudio") else: @@ -215,3 +215,16 @@ def _play( print("Warning: token generation is falling behind") _catch_up_warned = True time.sleep(0.1) + + +def _ensure_fluidsynth(): + try: + import fluidsynth # lazy import + fs = fluidsynth.Synth() + except Exception as e: + msg = ( + f"\nError in loading pyfluidsynth library. Possible solutions are:\n" + f" - install `pyfluidsynth` (pip install pyfluidsynth)\n" + f" - make sure gcc is in the system (e.g. conda install -c conda-forge gcc)\n" + ) + raise ImportError(msg) from e From 63da21273d347be09f536febe0b10d6cbb36837b Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Sun, 24 Dec 2023 09:30:52 -0500 Subject: [PATCH 9/9] format --- aria/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/aria/utils.py b/aria/utils.py index 665f1b9..e2e95ff 100644 --- a/aria/utils.py +++ b/aria/utils.py @@ -162,6 +162,7 @@ def _play( return import fluidsynth # lazy import + fs = fluidsynth.Synth() if platform.system() == "Linux": @@ -220,6 +221,7 @@ def _play( def _ensure_fluidsynth(): try: import fluidsynth # lazy import + fs = fluidsynth.Synth() except Exception as e: msg = (