From 745c6f49085a42e800ee240c884166e55f4dc73e Mon Sep 17 00:00:00 2001 From: skytnt Date: Tue, 12 Sep 2023 10:53:19 +0800 Subject: [PATCH] update app --- app.py | 106 +++++++------ app_onnx.py | 109 +++++++------ javascript/app.js | 389 ++++++++++++++++++++++++++++++++++++++++++++++ midi_tokenizer.py | 13 ++ 4 files changed, 518 insertions(+), 99 deletions(-) create mode 100644 javascript/app.js diff --git a/app.py b/app.py index ec727bf..681554d 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,7 @@ import argparse import glob +import json -import PIL import gradio as gr import numpy as np import torch @@ -15,6 +15,7 @@ from midi_synthesizer import synthesis from huggingface_hub import hf_hub_download + @torch.inference_mode() def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20, disable_patch_change=False, disable_control_change=False, disable_channels=None, amp=True): @@ -82,43 +83,14 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20, break +def create_msg(name, data): + return {"name": name, "data": data} + + def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc, amp): mid_seq = [] - max_len = int(gen_events) - img_len = 1024 - img = np.full((128 * 2, img_len, 3), 255, dtype=np.uint8) - state = {"t1": 0, "t": 0, "cur_pos": 0} - rand = np.random.RandomState(0) - colors = {(i, j): rand.randint(0, 200, 3) for i in range(128) for j in range(16)} - - def draw_event(tokens): - if tokens[0] in tokenizer.id_events: - name = tokenizer.id_events[tokens[0]] - if len(tokens) <= len(tokenizer.events[name]): - return - params = tokens[1:] - params = [params[i] - tokenizer.parameter_ids[p][0] for i, p in enumerate(tokenizer.events[name])] - if not all([0 <= params[i] < tokenizer.event_parameters[p] for i, p in enumerate(tokenizer.events[name])]): - return - event = [name] + params - state["t1"] += event[1] - t = state["t1"] * 16 + event[2] - state["t"] = t - if name == "note": - tr, d, c, p = event[3:7] - shift = t + d - (state["cur_pos"] + img_len) - if shift > 0: - img[:, :-shift] = img[:, shift:] - img[:, -shift:] = 255 - state["cur_pos"] += shift - t = t - state["cur_pos"] - img[p * 2:(p + 1) * 2, t: t + d] = colors[(tr, c)] - - def get_img(): - t = state["t"] - state["cur_pos"] - img_new = img.copy() - img_new[:, t: t + 2] = 0 - return PIL.Image.fromarray(np.flip(img_new, 0)) + gen_events = int(gen_events) + max_len = gen_events disable_patch_change = False disable_channels = None @@ -135,7 +107,7 @@ def get_img(): mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i, c, p])) mid_seq = mid mid = np.asarray(mid, dtype=np.int64) - if len(instruments) > 0 or drum_kit != "None": + if len(instruments) > 0: disable_patch_change = True disable_channels = [i for i in range(16) if i not in patches] elif mid is not None: @@ -144,20 +116,24 @@ def get_img(): mid = mid[:int(midi_events)] max_len += len(mid) for token_seq in mid: - mid_seq.append(token_seq) - draw_event(token_seq) + mid_seq.append(token_seq.tolist()) + + init_msgs = [create_msg("visualizer_clear", None)] + for tokens in mid_seq: + init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens))) + yield mid_seq, None, None, init_msgs generator = generate(mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k, disable_patch_change=disable_patch_change, disable_control_change=not allow_cc, disable_channels=disable_channels, amp=amp) - for token_seq in generator: + for i, token_seq in enumerate(generator): mid_seq.append(token_seq) - 
draw_event(token_seq)
-        yield mid_seq, get_img(), None, None
+        event = tokenizer.tokens2event(token_seq.tolist())
+        yield mid_seq, None, None, [create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])]
     mid = tokenizer.detokenize(mid_seq)
     with open(f"output.mid", 'wb') as f:
         f.write(MIDI.score2midi(mid))
     audio = synthesis(MIDI.score2opus(mid), soundfont_path)
-    yield mid_seq, get_img(), "output.mid", (44100, audio)
+    yield mid_seq, "output.mid", (44100, audio), [create_msg("visualizer_end", None)]


 def cancel_run(mid_seq):
@@ -165,7 +141,7 @@
     with open(f"output.mid", 'wb') as f:
         f.write(MIDI.score2midi(mid))
     audio = synthesis(MIDI.score2opus(mid), soundfont_path)
-    return "output.mid", (44100, audio)
+    return "output.mid", (44100, audio), [create_msg("visualizer_end", None)]


 def load_model(path):
@@ -181,6 +157,38 @@ def get_model_path():
     return model_path_input.update(choices=model_paths)


+def load_javascript(dir="javascript"):
+    scripts_list = glob.glob(f"{dir}/*.js")
+    javascript = ""
+    for path in scripts_list:
+        with open(path, "r", encoding="utf8") as jsfile:
+            javascript += f"\n<script>{jsfile.read()}</script>"
+    template_response_ori = gr.routes.templates.TemplateResponse
+
+    def template_response(*args, **kwargs):
+        res = template_response_ori(*args, **kwargs)
+        res.body = res.body.replace(
+            b'</head>', f'{javascript}</head>'.encode("utf8"))
+        res.init_headers()
+        return res
+
+    gr.routes.templates.TemplateResponse = template_response
+
+
+class JSMsgReceiver(gr.HTML):
+
+    def __init__(self, **kwargs):
+        super().__init__(elem_id="msg_receiver", visible=False, **kwargs)
+
+    def postprocess(self, y):
+        if y:
+            y = f"
<p>{json.dumps(y)}</p>
" + return super().postprocess(y) + + def get_block_name(self) -> str: + return "html" + + number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz", 40: "Blush", 48: "Orchestra"} patch2number = {v: k for k, v in MIDI.Number2patch.items()} @@ -195,8 +203,10 @@ def get_model_path(): tokenizer = MIDITokenizer() model = MIDIModel(tokenizer).to(device=opt.device) + load_javascript() app = gr.Blocks() with app: + js_msg = JSMsgReceiver() with gr.Accordion(label="Model option", open=False): load_model_path_btn = gr.Button("Get Models") model_path_input = gr.Dropdown(label="model") @@ -243,12 +253,12 @@ def get_model_path(): run_btn = gr.Button("generate", variant="primary") stop_btn = gr.Button("stop and output") output_midi_seq = gr.Variable() - output_midi_img = gr.Image(label="output image") + output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container") + output_audio = gr.Audio(label="output audio", format="wav", elem_id="midi_audio") output_midi = gr.File(label="output midi", file_types=[".mid"]) - output_audio = gr.Audio(label="output audio", format="mp3") run_event = run_btn.click(run, [tab_select, input_instruments, input_drum_kit, input_midi, input_midi_events, input_gen_events, input_temp, input_top_p, input_top_k, input_allow_cc, input_amp], - [output_midi_seq, output_midi_img, output_midi, output_audio]) - stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio], cancels=run_event, queue=False) + [output_midi_seq, output_midi, output_audio, js_msg]) + stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg], cancels=run_event, queue=False) app.queue(1).launch(server_port=opt.port) diff --git a/app_onnx.py b/app_onnx.py index 73ded62..e788baa 100644 --- a/app_onnx.py +++ b/app_onnx.py @@ -1,12 +1,12 @@ import argparse +import glob +import json import os.path from sys import exit import gradio as gr import numpy as np import onnxruntime as rt -import PIL -import PIL.ImageColor import requests import tqdm @@ -105,44 +105,14 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20, break +def create_msg(name, data): + return {"name": name, "data": data} + + def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc): mid_seq = [] - max_len = int(gen_events) - img_len = 1024 - img = np.full((128 * 2, img_len, 3), 255, dtype=np.uint8) - state = {"t1": 0, "t": 0, "cur_pos": 0} - colors = ['navy', 'blue', 'deepskyblue', 'teal', 'green', 'lightgreen', 'lime', 'orange', - 'brown', 'grey', 'red', 'pink', 'aqua', 'orchid', 'bisque', 'coral'] - colors = [PIL.ImageColor.getrgb(color) for color in colors] - - def draw_event(tokens): - if tokens[0] in tokenizer.id_events: - name = tokenizer.id_events[tokens[0]] - if len(tokens) <= len(tokenizer.events[name]): - return - params = tokens[1:] - params = [params[i] - tokenizer.parameter_ids[p][0] for i, p in enumerate(tokenizer.events[name])] - if not all([0 <= params[i] < tokenizer.event_parameters[p] for i, p in enumerate(tokenizer.events[name])]): - return - event = [name] + params - state["t1"] += event[1] - t = state["t1"] * 16 + event[2] - state["t"] = t - if name == "note": - tr, d, c, p = event[3:7] - shift = t + d - (state["cur_pos"] + img_len) - if shift > 0: - img[:, :-shift] = img[:, shift:] - img[:, -shift:] = 255 - state["cur_pos"] += shift - t = t - state["cur_pos"] - img[p * 2:(p + 1) * 2, t: t + d] = colors[c] - - def get_img(): - t = state["t"] - state["cur_pos"] - img_new = img.copy() 
-        img_new[:, t: t + 2] = 0
-        return PIL.Image.fromarray(np.flip(img_new, 0))
+    gen_events = int(gen_events)
+    max_len = gen_events

     disable_patch_change = False
     disable_channels = None
@@ -168,20 +138,24 @@ def get_img():
         mid = mid[:int(midi_events)]
         max_len += len(mid)
         for token_seq in mid:
-            mid_seq.append(token_seq)
-            draw_event(token_seq)
+            mid_seq.append(token_seq.tolist())
+
+    init_msgs = [create_msg("visualizer_clear", None)]
+    for tokens in mid_seq:
+        init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
+    yield mid_seq, None, None, init_msgs
     generator = generate(mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
                          disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
                          disable_channels=disable_channels)
-    for token_seq in generator:
+    for i, token_seq in enumerate(generator):
         mid_seq.append(token_seq)
-        draw_event(token_seq)
-        yield mid_seq, get_img(), None, None
+        event = tokenizer.tokens2event(token_seq.tolist())
+        yield mid_seq, None, None, [create_msg("visualizer_append", event), create_msg("progress", [i+1, gen_events])]
     mid = tokenizer.detokenize(mid_seq)
     with open(f"output.mid", 'wb') as f:
         f.write(MIDI.score2midi(mid))
     audio = synthesis(MIDI.score2opus(mid), opt.soundfont_path)
-    yield mid_seq, get_img(), "output.mid", (44100, audio)
+    yield mid_seq, "output.mid", (44100, audio), [create_msg("visualizer_end", None)]


 def cancel_run(mid_seq):
@@ -191,7 +165,7 @@
     with open(f"output.mid", 'wb') as f:
         f.write(MIDI.score2midi(mid))
     audio = synthesis(MIDI.score2opus(mid), opt.soundfont_path)
-    return "output.mid", (44100, audio)
+    return "output.mid", (44100, audio), [create_msg("visualizer_end", None)]


 def download(url, output_file):
@@ -216,6 +190,38 @@ def download_if_not_exit(url, output_file):
         raise e


+def load_javascript(dir="javascript"):
+    scripts_list = glob.glob(f"{dir}/*.js")
+    javascript = ""
+    for path in scripts_list:
+        with open(path, "r", encoding="utf8") as jsfile:
+            javascript += f"\n<script>{jsfile.read()}</script>"
+    template_response_ori = gr.routes.templates.TemplateResponse
+
+    def template_response(*args, **kwargs):
+        res = template_response_ori(*args, **kwargs)
+        res.body = res.body.replace(
+            b'</head>', f'{javascript}</head>'.encode("utf8"))
+        res.init_headers()
+        return res
+
+    gr.routes.templates.TemplateResponse = template_response
+
+
+class JSMsgReceiver(gr.HTML):
+
+    def __init__(self, **kwargs):
+        super().__init__(elem_id="msg_receiver", visible=False, **kwargs)
+
+    def postprocess(self, y):
+        if y:
+            y = f"
<p>{json.dumps(y)}</p>
" + return super().postprocess(y) + + def get_block_name(self) -> str: + return "html" + + number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz", 40: "Blush", 48: "Orchestra"} patch2number = {v: k for k, v in MIDI.Number2patch.items()} @@ -259,6 +265,7 @@ def download_if_not_exit(url, output_file): input("Failed to load models, maybe you need to delete them and re-download it.\nPress any key to continue...") exit(-1) + load_javascript() app = gr.Blocks() with app: gr.Markdown("
<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>
") @@ -269,7 +276,7 @@ def download_if_not_exit(url, output_file): "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)" " for faster running" ) - + js_msg = JSMsgReceiver() tab_select = gr.Variable(value=0) with gr.Tabs(): with gr.TabItem("instrument prompt") as tab1: @@ -290,7 +297,7 @@ def download_if_not_exit(url, output_file): ], [input_instruments, input_drum_kit]) with gr.TabItem("midi prompt") as tab2: input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary") - input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512, + input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=1024, step=1, value=128) @@ -307,14 +314,14 @@ def download_if_not_exit(url, output_file): run_btn = gr.Button("generate", variant="primary") stop_btn = gr.Button("stop and output") output_midi_seq = gr.Variable() - output_midi_img = gr.Image(label="output image") + output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container") + output_audio = gr.Audio(label="output audio", format="wav", elem_id="midi_audio") output_midi = gr.File(label="output midi", file_types=[".mid"]) - output_audio = gr.Audio(label="output audio", format="wav") run_event = run_btn.click(run, [tab_select, input_instruments, input_drum_kit, input_midi, input_midi_events, input_gen_events, input_temp, input_top_p, input_top_k, input_allow_cc], - [output_midi_seq, output_midi_img, output_midi, output_audio]) - stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio], cancels=run_event, queue=False) + [output_midi_seq, output_midi, output_audio, js_msg]) + stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg], cancels=run_event, queue=False) try: port = opt.port if port == -1: diff --git a/javascript/app.js b/javascript/app.js new file mode 100644 index 0000000..0a8d121 --- /dev/null +++ b/javascript/app.js @@ -0,0 +1,389 @@ +function gradioApp() { + const elems = document.getElementsByTagName('gradio-app') + const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot + return !!gradioShadowRoot ? 
gradioShadowRoot : document; +} + +uiUpdateCallbacks = [] +msgReceiveCallbacks = [] + +function onUiUpdate(callback){ + uiUpdateCallbacks.push(callback) +} + +function onMsgReceive(callback){ + msgReceiveCallbacks.push(callback) +} + +function runCallback(x, m){ + try { + x(m) + } catch (e) { + (console.error || console.log).call(console, e.message, e); + } +} +function executeCallbacks(queue, m) { + queue.forEach(function(x){runCallback(x, m)}) +} + +document.addEventListener("DOMContentLoaded", function() { + var mutationObserver = new MutationObserver(function(m){ + executeCallbacks(uiUpdateCallbacks, m); + }); + mutationObserver.observe( gradioApp(), { childList:true, subtree:true }) +}); + +(()=>{ + let mse_receiver_inited = null + onUiUpdate(()=>{ + let app = gradioApp() + let msg_receiver = app.querySelector("#msg_receiver"); + if(!!msg_receiver && mse_receiver_inited !== msg_receiver){ + let mutationObserver = new MutationObserver(function(ms){ + ms.forEach((m)=>{ + m.addedNodes.forEach((node)=>{ + if(node.nodeName === "P"){ + let obj = JSON.parse(node.innerText); + if(obj instanceof Array){ + obj.forEach((o)=>{executeCallbacks(msgReceiveCallbacks, o);}); + }else{ + executeCallbacks(msgReceiveCallbacks, obj); + } + } + }) + }) + }); + mutationObserver.observe( msg_receiver, {childList:true, subtree:true, characterData:true}) + console.log("receiver init"); + mse_receiver_inited = msg_receiver; + } + }) +})() + +function HSVtoRGB(h, s, v) { + let r, g, b, i, f, p, q, t; + i = Math.floor(h * 6); + f = h * 6 - i; + p = v * (1 - s); + q = v * (1 - f * s); + t = v * (1 - (1 - f) * s); + switch (i % 6) { + case 0: r = v; g = t; b = p; break; + case 1: r = q; g = v; b = p; break; + case 2: r = p; g = v; b = t; break; + case 3: r = p; g = q; b = v; break; + case 4: r = t; g = p; b = v; break; + case 5: r = v; g = p; b = q; break; + } + return { + r: Math.round(r * 255), + g: Math.round(g * 255), + b: Math.round(b * 255) + }; +} + +class MidiVisualizer extends HTMLElement{ + constructor() { + super(); + this.midiEvents = []; + this.activeNotes = []; + this.midiTimes = []; + this.wrapper = null; + this.svg = null; + this.timeLine = null; + this.config = { + noteHeight : 4, + beatWidth: 32 + } + this.timePreBeat = 16 + this.svgWidth = 0; + this.t1 = 0; + this.playTime = 0 + this.playTimeMs = 0 + this.colorMap = new Map(); + this.playing = false; + this.timer = null; + this.init(); + } + + init(){ + this.innerHTML='' + const shadow = this.attachShadow({mode: 'open'}); + const style = document.createElement("style"); + const wrapper = document.createElement('div'); + style.textContent = ".note.active {stroke: black;stroke-width: 0.75;stroke-opacity: 0.75;}"; + wrapper.style.overflowX= "scroll" + const svg = document.createElementNS('http://www.w3.org/2000/svg', 'svg'); + svg.style.height = `${this.config.noteHeight*128}px`; + svg.style.width = `${this.svgWidth}px`; + const timeLine = document.createElementNS('http://www.w3.org/2000/svg', 'line'); + timeLine.style.stroke = "green" + timeLine.style.strokeWidth = 2; + shadow.appendChild(style) + shadow.appendChild(wrapper); + wrapper.appendChild(svg); + svg.appendChild(timeLine) + this.wrapper = wrapper; + this.svg = svg; + this.timeLine= timeLine; + this.setPlayTime(0); + } + + clearMidiEvents(){ + this.pause() + this.midiEvents = []; + this.activeNotes = []; + this.midiTimes = []; + this.t1 = 0 + this.colorMap.clear() + this.setPlayTime(0); + this.playTimeMs = 0 + this.svgWidth = 0 + this.svg.innerHTML = '' + this.svg.style.width = 
`${this.svgWidth}px`;
        this.svg.appendChild(this.timeLine)
    }

    appendMidiEvent(midiEvent){
        if(midiEvent instanceof Array && midiEvent.length > 0){

            this.t1 += midiEvent[1]
            let t = this.t1*this.timePreBeat + midiEvent[2]
            midiEvent = [midiEvent[0], t].concat(midiEvent.slice(3))
            if(midiEvent[0] === "note"){
                let track = midiEvent[2]
                let duration = midiEvent[3]
                let channel = midiEvent[4]
                let pitch = midiEvent[5]
                let velocity = midiEvent[6]
                let x = (t/this.timePreBeat)*this.config.beatWidth
                let y = (127 - pitch)*this.config.noteHeight
                let w = (duration/this.timePreBeat)*this.config.beatWidth
                let h = this.config.noteHeight
                this.svgWidth = Math.ceil(Math.max(x + w, this.svgWidth))
                let color = this.getColor(track, channel)
                let opacity = Math.min(1, velocity/127 + 0.1).toFixed(2)
                let rect = this.drawNote(x,y,w,h, `rgba(${color.r}, ${color.g}, ${color.b}, ${opacity})`)
                midiEvent.push(rect)
                this.setPlayTime(t);
                this.wrapper.scrollTo(this.svgWidth - this.wrapper.offsetWidth, 0)
            }
            this.midiEvents.push(midiEvent);
            this.svg.style.width = `${this.svgWidth}px`;
        }

    }

    getColor(track, channel){
        let key = `${track},${channel}`;
        let color = this.colorMap.get(key);
        if(!!color){
            return color;
        }
        color = HSVtoRGB(Math.random(),Math.random()*0.5 + 0.5,1);
        this.colorMap.set(key, color);
        return color;
    }

    drawNote(x, y, w, h, fill) {
        if (!this.svg) {
            return null;
        }
        const rect = document.createElementNS('http://www.w3.org/2000/svg', 'rect');
        rect.classList.add('note');
        rect.setAttribute('fill', fill);
        // Round values to the nearest integer to avoid partially filled pixels.
        rect.setAttribute('x', `${Math.round(x)}`);
        rect.setAttribute('y', `${Math.round(y)}`);
        rect.setAttribute('width', `${Math.round(w)}`);
        rect.setAttribute('height', `${Math.round(h)}`);
        this.svg.appendChild(rect);
        return rect
    }

    finishAppendMidiEvent(){
        this.pause()
        let midiEvents = this.midiEvents.sort((a, b)=>a[1]-b[1])
        let tempo = (60 / 120) * 10 ** 3
        let ms = 0
        let lastT = 0
        this.midiTimes.push({ms:ms, t: 0, tempo: tempo})
        midiEvents.forEach((midiEvent)=>{
            let t = midiEvent[1]
            ms += ((t- lastT) / this.timePreBeat) * tempo
            if(midiEvent[0]==="set_tempo"){
                tempo = (60 / midiEvent[3]) * 10 ** 3
                this.midiTimes.push({ms:ms, t: t, tempo: tempo})
            }
            lastT = t
        })
    }

    setPlayTime(t){
        this.playTime = t
        let x = Math.round((t/this.timePreBeat)*this.config.beatWidth)
        this.timeLine.setAttribute('x1', `${x}`);
        this.timeLine.setAttribute('y1', '0');
        this.timeLine.setAttribute('x2', `${x}`);
        this.timeLine.setAttribute('y2', `${this.config.noteHeight*128}`);

        this.wrapper.scrollTo(Math.max(0, x - this.wrapper.offsetWidth/2), 0)

        if(this.playing){
            let activeNotes = []
            this.removeActiveNotes(this.activeNotes)
            this.midiEvents.forEach((midiEvent)=>{
                if(midiEvent[0] === "note"){
                    let time = midiEvent[1]
                    let duration = midiEvent[3]
                    let note = midiEvent[midiEvent.length - 1]
                    if(time <=this.playTime && time+duration>= this.playTime){
                        activeNotes.push(note)
                    }
                }
            })
            this.addActiveNotes(activeNotes)
        }
    }

    setPlayTimeMs(ms){
        this.playTimeMs = ms
        let playTime = 0
        for(let i =0;i<this.midiTimes.length;i++){
            let midiTime = this.midiTimes[i]
            if(midiTime.ms>=ms){
                break;
            }
            playTime = midiTime.t + (ms-midiTime.ms) * this.timePreBeat / midiTime.tempo
        }
        this.setPlayTime(playTime)
    }

    addActiveNotes(notes){
        notes.forEach((note)=>{
            this.activeNotes.push(note)
            note.classList.add('active');
        });
    }

    removeActiveNotes(notes){
notes.forEach((note)=>{ + let idx = this.activeNotes.indexOf(note) + if(idx>-1) + this.activeNotes.splice(idx, 1); + note.classList.remove('active'); + }); + } + + play(){ + this.playing = true; + this.timer = setInterval(() => { + this.setPlayTimeMs(this.playTimeMs + 10) + }, 10); + } + + pause(){ + if(!!this.timer) + clearInterval(this.timer) + this.removeActiveNotes(this.activeNotes) + this.timer = null; + this.playing = false; + } + + + bindAudioPlayer(audio){ + this.pause() + audio.addEventListener("play", (event)=>{ + this.play() + }) + audio.addEventListener("pause", (event)=>{ + this.pause() + }) + audio.addEventListener("timeupdate", (event)=>{ + this.setPlayTimeMs(event.target.currentTime*10**3) + }) + } +} + +customElements.define('midi-visualizer', MidiVisualizer); + +(()=>{ + let midi_visualizer_container_inited = null + let midi_audio_inited = null; + let midi_visualizer = document.createElement('midi-visualizer') + onUiUpdate((m)=>{ + let app = gradioApp() + let midi_visualizer_container = app.querySelector("#midi_visualizer_container"); + if(!!midi_visualizer_container && midi_visualizer_container_inited!== midi_visualizer_container){ + midi_visualizer_container.appendChild(midi_visualizer) + midi_visualizer_container_inited = midi_visualizer_container; + } + let midi_audio = app.querySelector("#midi_audio > audio"); + if(!!midi_audio && midi_audio_inited!==midi_audio){ + midi_visualizer.bindAudioPlayer(midi_audio) + midi_audio_inited = midi_audio + } + }) + + function createProgressBar(progressbarContainer){ + let parentProgressbar = progressbarContainer.parentNode; + let divProgress = document.createElement('div'); + divProgress.className='progressDiv'; + let rect = progressbarContainer.getBoundingClientRect(); + divProgress.style.width = rect.width + "px"; + divProgress.style.background = "#b4c0cc"; + divProgress.style.borderRadius = "8px"; + let divInner = document.createElement('div'); + divInner.className='progress'; + divInner.style.color = "white"; + divInner.style.background = "#0060df"; + divInner.style.textAlign = "right"; + divInner.style.fontWeight = "bold"; + divInner.style.borderRadius = "8px"; + divInner.style.height = "20px"; + divInner.style.lineHeight = "20px"; + divInner.style.paddingRight = "8px" + divInner.style.width = "0%"; + divProgress.appendChild(divInner); + parentProgressbar.insertBefore(divProgress, progressbarContainer); + } + + function removeProgressBar(progressbarContainer){ + let parentProgressbar = progressbarContainer.parentNode; + let divProgress = parentProgressbar.querySelector(".progressDiv"); + parentProgressbar.removeChild(divProgress); + } + + function setProgressBar(progressbarContainer, progress, total){ + let parentProgressbar = progressbarContainer.parentNode; + let divProgress = parentProgressbar.querySelector(".progressDiv"); + let divInner = parentProgressbar.querySelector(".progress"); + if(total===0) + total = 1; + divInner.style.width = `${(progress/total)*100}%`; + divInner.textContent = `${progress}/${total}`; + } + + onMsgReceive((msg)=>{ + switch (msg.name) { + case "visualizer_clear": + midi_visualizer.clearMidiEvents(); + createProgressBar(midi_visualizer_container_inited) + break; + case "visualizer_append": + midi_visualizer.appendMidiEvent(msg.data); + break; + case "progress": + let progress = msg.data[0] + let total = msg.data[1] + setProgressBar(midi_visualizer_container_inited, progress, total) + break; + case "visualizer_end": + midi_visualizer.finishAppendMidiEvent() + midi_visualizer.setPlayTime(0); + 
removeProgressBar(midi_visualizer_container_inited); + break; + default: + } + }) +})(); diff --git a/midi_tokenizer.py b/midi_tokenizer.py index 1335baf..ab287e5 100644 --- a/midi_tokenizer.py +++ b/midi_tokenizer.py @@ -101,6 +101,19 @@ def event2tokens(self, event): tokens += [self.pad_id] * (self.max_token_seq - len(tokens)) return tokens + def tokens2event(self, tokens): + if tokens[0] in self.id_events: + name = self.id_events[tokens[0]] + if len(tokens) <= len(self.events[name]): + return [] + params = tokens[1:] + params = [params[i] - self.parameter_ids[p][0] for i, p in enumerate(self.events[name])] + if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]): + return [] + event = [name] + params + return event + return [] + def detokenize(self, midi_seq): ticks_per_beat = 480 tracks_dict = {}
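
Note for reviewers: the Python-to-JS bridge added here is plain JSON in a hidden HTML block. `run`/`cancel_run` yield message lists built with `create_msg`, `JSMsgReceiver.postprocess` wraps each batch in a `<p>` element, and `app.js` observes `#msg_receiver`, JSON-parses each added `<p>` node, and dispatches the messages (`visualizer_clear`, `visualizer_append`, `progress`, `visualizer_end`) to the `onMsgReceive` callbacks. A minimal sketch of the contract, runnable outside Gradio — the example event values are illustrative, not taken from the patch:

```python
import json

def create_msg(name, data):
    # Same message shape the patch yields to the js_msg output component.
    return {"name": name, "data": data}

# One generation step: append a tokenizer event, then report progress.
# "note" events are [name, delta_beats, ticks, track, duration, channel, pitch, velocity].
batch = [
    create_msg("visualizer_append", ["note", 0, 0, 0, 16, 0, 64, 80]),
    create_msg("progress", [1, 128]),
]

# JSMsgReceiver.postprocess renders the batch like this; app.js parses the
# <p> node's innerText and fans each message out to its registered callbacks.
print(f"<p>{json.dumps(batch)}</p>")
```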