diff --git a/app.py b/app.py
index ec727bf..681554d 100644
--- a/app.py
+++ b/app.py
@@ -1,7 +1,7 @@
import argparse
import glob
+import json
-import PIL
import gradio as gr
import numpy as np
import torch
@@ -15,6 +15,7 @@
from midi_synthesizer import synthesis
from huggingface_hub import hf_hub_download
+
@torch.inference_mode()
def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
disable_patch_change=False, disable_control_change=False, disable_channels=None, amp=True):
@@ -82,43 +83,14 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
break
+def create_msg(name, data):
+ return {"name": name, "data": data}
+
+
def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc, amp):
mid_seq = []
- max_len = int(gen_events)
- img_len = 1024
- img = np.full((128 * 2, img_len, 3), 255, dtype=np.uint8)
- state = {"t1": 0, "t": 0, "cur_pos": 0}
- rand = np.random.RandomState(0)
- colors = {(i, j): rand.randint(0, 200, 3) for i in range(128) for j in range(16)}
-
- def draw_event(tokens):
- if tokens[0] in tokenizer.id_events:
- name = tokenizer.id_events[tokens[0]]
- if len(tokens) <= len(tokenizer.events[name]):
- return
- params = tokens[1:]
- params = [params[i] - tokenizer.parameter_ids[p][0] for i, p in enumerate(tokenizer.events[name])]
- if not all([0 <= params[i] < tokenizer.event_parameters[p] for i, p in enumerate(tokenizer.events[name])]):
- return
- event = [name] + params
- state["t1"] += event[1]
- t = state["t1"] * 16 + event[2]
- state["t"] = t
- if name == "note":
- tr, d, c, p = event[3:7]
- shift = t + d - (state["cur_pos"] + img_len)
- if shift > 0:
- img[:, :-shift] = img[:, shift:]
- img[:, -shift:] = 255
- state["cur_pos"] += shift
- t = t - state["cur_pos"]
- img[p * 2:(p + 1) * 2, t: t + d] = colors[(tr, c)]
-
- def get_img():
- t = state["t"] - state["cur_pos"]
- img_new = img.copy()
- img_new[:, t: t + 2] = 0
- return PIL.Image.fromarray(np.flip(img_new, 0))
+ gen_events = int(gen_events)
+ max_len = gen_events
disable_patch_change = False
disable_channels = None
@@ -135,7 +107,7 @@ def get_img():
mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i, c, p]))
mid_seq = mid
mid = np.asarray(mid, dtype=np.int64)
- if len(instruments) > 0 or drum_kit != "None":
+ if len(instruments) > 0:
disable_patch_change = True
disable_channels = [i for i in range(16) if i not in patches]
elif mid is not None:
@@ -144,20 +116,24 @@ def get_img():
mid = mid[:int(midi_events)]
max_len += len(mid)
for token_seq in mid:
- mid_seq.append(token_seq)
- draw_event(token_seq)
+ mid_seq.append(token_seq.tolist())
+
+ init_msgs = [create_msg("visualizer_clear", None)]
+ for tokens in mid_seq:
+ init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
+ yield mid_seq, None, None, init_msgs
generator = generate(mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
disable_channels=disable_channels, amp=amp)
- for token_seq in generator:
+ for i, token_seq in enumerate(generator):
mid_seq.append(token_seq)
- draw_event(token_seq)
- yield mid_seq, get_img(), None, None
+ event = tokenizer.tokens2event(token_seq.tolist())
+ yield mid_seq, None, None, [create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])]
mid = tokenizer.detokenize(mid_seq)
with open(f"output.mid", 'wb') as f:
f.write(MIDI.score2midi(mid))
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
- yield mid_seq, get_img(), "output.mid", (44100, audio)
+ yield mid_seq, "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
def cancel_run(mid_seq):
@@ -165,7 +141,7 @@ def cancel_run(mid_seq):
with open(f"output.mid", 'wb') as f:
f.write(MIDI.score2midi(mid))
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
- return "output.mid", (44100, audio)
+ return "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
def load_model(path):
@@ -181,6 +157,38 @@ def get_model_path():
return model_path_input.update(choices=model_paths)
+def load_javascript(dir="javascript"):
+ scripts_list = glob.glob(f"{dir}/*.js")
+ javascript = ""
+ for path in scripts_list:
+ with open(path, "r", encoding="utf8") as jsfile:
+ javascript += f"\n"
+ template_response_ori = gr.routes.templates.TemplateResponse
+
+ def template_response(*args, **kwargs):
+ res = template_response_ori(*args, **kwargs)
+ res.body = res.body.replace(
+ b'', f'{javascript}'.encode("utf8"))
+ res.init_headers()
+ return res
+
+ gr.routes.templates.TemplateResponse = template_response
+
+
+class JSMsgReceiver(gr.HTML):
+
+ def __init__(self, **kwargs):
+ super().__init__(elem_id="msg_receiver", visible=False, **kwargs)
+
+ def postprocess(self, y):
+ if y:
+ y = f"
{json.dumps(y)}
"
+ return super().postprocess(y)
+
+ def get_block_name(self) -> str:
+ return "html"
+
+
number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
40: "Blush", 48: "Orchestra"}
patch2number = {v: k for k, v in MIDI.Number2patch.items()}
@@ -195,8 +203,10 @@ def get_model_path():
tokenizer = MIDITokenizer()
model = MIDIModel(tokenizer).to(device=opt.device)
+ load_javascript()
app = gr.Blocks()
with app:
+ js_msg = JSMsgReceiver()
with gr.Accordion(label="Model option", open=False):
load_model_path_btn = gr.Button("Get Models")
model_path_input = gr.Dropdown(label="model")
@@ -243,12 +253,12 @@ def get_model_path():
run_btn = gr.Button("generate", variant="primary")
stop_btn = gr.Button("stop and output")
output_midi_seq = gr.Variable()
- output_midi_img = gr.Image(label="output image")
+ output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
+ output_audio = gr.Audio(label="output audio", format="wav", elem_id="midi_audio")
output_midi = gr.File(label="output midi", file_types=[".mid"])
- output_audio = gr.Audio(label="output audio", format="mp3")
run_event = run_btn.click(run, [tab_select, input_instruments, input_drum_kit, input_midi, input_midi_events,
input_gen_events, input_temp, input_top_p, input_top_k,
input_allow_cc, input_amp],
- [output_midi_seq, output_midi_img, output_midi, output_audio])
- stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio], cancels=run_event, queue=False)
+ [output_midi_seq, output_midi, output_audio, js_msg])
+ stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
app.queue(1).launch(server_port=opt.port)
diff --git a/app_onnx.py b/app_onnx.py
index 73ded62..e788baa 100644
--- a/app_onnx.py
+++ b/app_onnx.py
@@ -1,12 +1,12 @@
import argparse
+import glob
+import json
import os.path
from sys import exit
import gradio as gr
import numpy as np
import onnxruntime as rt
-import PIL
-import PIL.ImageColor
import requests
import tqdm
@@ -105,44 +105,14 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
break
+def create_msg(name, data):
+ return {"name": name, "data": data}
+
+
def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
mid_seq = []
- max_len = int(gen_events)
- img_len = 1024
- img = np.full((128 * 2, img_len, 3), 255, dtype=np.uint8)
- state = {"t1": 0, "t": 0, "cur_pos": 0}
- colors = ['navy', 'blue', 'deepskyblue', 'teal', 'green', 'lightgreen', 'lime', 'orange',
- 'brown', 'grey', 'red', 'pink', 'aqua', 'orchid', 'bisque', 'coral']
- colors = [PIL.ImageColor.getrgb(color) for color in colors]
-
- def draw_event(tokens):
- if tokens[0] in tokenizer.id_events:
- name = tokenizer.id_events[tokens[0]]
- if len(tokens) <= len(tokenizer.events[name]):
- return
- params = tokens[1:]
- params = [params[i] - tokenizer.parameter_ids[p][0] for i, p in enumerate(tokenizer.events[name])]
- if not all([0 <= params[i] < tokenizer.event_parameters[p] for i, p in enumerate(tokenizer.events[name])]):
- return
- event = [name] + params
- state["t1"] += event[1]
- t = state["t1"] * 16 + event[2]
- state["t"] = t
- if name == "note":
- tr, d, c, p = event[3:7]
- shift = t + d - (state["cur_pos"] + img_len)
- if shift > 0:
- img[:, :-shift] = img[:, shift:]
- img[:, -shift:] = 255
- state["cur_pos"] += shift
- t = t - state["cur_pos"]
- img[p * 2:(p + 1) * 2, t: t + d] = colors[c]
-
- def get_img():
- t = state["t"] - state["cur_pos"]
- img_new = img.copy()
- img_new[:, t: t + 2] = 0
- return PIL.Image.fromarray(np.flip(img_new, 0))
+ gen_events = int(gen_events)
+ max_len = gen_events
disable_patch_change = False
disable_channels = None
@@ -168,20 +138,24 @@ def get_img():
mid = mid[:int(midi_events)]
max_len += len(mid)
for token_seq in mid:
- mid_seq.append(token_seq)
- draw_event(token_seq)
+ mid_seq.append(token_seq.tolist())
+
+ init_msgs = [create_msg("visualizer_clear", None)]
+ for tokens in mid_seq:
+ init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
+ yield mid_seq, None, None, init_msgs
generator = generate(mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
disable_channels=disable_channels)
- for token_seq in generator:
+ for i, token_seq in enumerate(generator):
mid_seq.append(token_seq)
- draw_event(token_seq)
- yield mid_seq, get_img(), None, None
+ event = tokenizer.tokens2event(token_seq.tolist())
+ yield mid_seq, None, None, [create_msg("visualizer_append", event), create_msg("progress", [i+1, gen_events])]
mid = tokenizer.detokenize(mid_seq)
with open(f"output.mid", 'wb') as f:
f.write(MIDI.score2midi(mid))
audio = synthesis(MIDI.score2opus(mid), opt.soundfont_path)
- yield mid_seq, get_img(), "output.mid", (44100, audio)
+ yield mid_seq, "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
def cancel_run(mid_seq):
@@ -191,7 +165,7 @@ def cancel_run(mid_seq):
with open(f"output.mid", 'wb') as f:
f.write(MIDI.score2midi(mid))
audio = synthesis(MIDI.score2opus(mid), opt.soundfont_path)
- return "output.mid", (44100, audio)
+ return "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
def download(url, output_file):
@@ -216,6 +190,38 @@ def download_if_not_exit(url, output_file):
raise e
+def load_javascript(dir="javascript"):
+ scripts_list = glob.glob(f"{dir}/*.js")
+ javascript = ""
+ for path in scripts_list:
+ with open(path, "r", encoding="utf8") as jsfile:
+ javascript += f"\n"
+ template_response_ori = gr.routes.templates.TemplateResponse
+
+ def template_response(*args, **kwargs):
+ res = template_response_ori(*args, **kwargs)
+ res.body = res.body.replace(
+ b'', f'{javascript}'.encode("utf8"))
+ res.init_headers()
+ return res
+
+ gr.routes.templates.TemplateResponse = template_response
+
+
+class JSMsgReceiver(gr.HTML):
+
+ def __init__(self, **kwargs):
+ super().__init__(elem_id="msg_receiver", visible=False, **kwargs)
+
+ def postprocess(self, y):
+ if y:
+ y = f"{json.dumps(y)}
"
+ return super().postprocess(y)
+
+ def get_block_name(self) -> str:
+ return "html"
+
+
number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
40: "Blush", 48: "Orchestra"}
patch2number = {v: k for k, v in MIDI.Number2patch.items()}
@@ -259,6 +265,7 @@ def download_if_not_exit(url, output_file):
input("Failed to load models, maybe you need to delete them and re-download it.\nPress any key to continue...")
exit(-1)
+ load_javascript()
app = gr.Blocks()
with app:
gr.Markdown("Midi Composer
")
@@ -269,7 +276,7 @@ def download_if_not_exit(url, output_file):
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
" for faster running"
)
-
+ js_msg = JSMsgReceiver()
tab_select = gr.Variable(value=0)
with gr.Tabs():
with gr.TabItem("instrument prompt") as tab1:
@@ -290,7 +297,7 @@ def download_if_not_exit(url, output_file):
], [input_instruments, input_drum_kit])
with gr.TabItem("midi prompt") as tab2:
input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")
- input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
+ input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=1024,
step=1,
value=128)
@@ -307,14 +314,14 @@ def download_if_not_exit(url, output_file):
run_btn = gr.Button("generate", variant="primary")
stop_btn = gr.Button("stop and output")
output_midi_seq = gr.Variable()
- output_midi_img = gr.Image(label="output image")
+ output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
+ output_audio = gr.Audio(label="output audio", format="wav", elem_id="midi_audio")
output_midi = gr.File(label="output midi", file_types=[".mid"])
- output_audio = gr.Audio(label="output audio", format="wav")
run_event = run_btn.click(run, [tab_select, input_instruments, input_drum_kit, input_midi, input_midi_events,
input_gen_events, input_temp, input_top_p, input_top_k,
input_allow_cc],
- [output_midi_seq, output_midi_img, output_midi, output_audio])
- stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio], cancels=run_event, queue=False)
+ [output_midi_seq, output_midi, output_audio, js_msg])
+ stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
try:
port = opt.port
if port == -1:
diff --git a/javascript/app.js b/javascript/app.js
new file mode 100644
index 0000000..0a8d121
--- /dev/null
+++ b/javascript/app.js
@@ -0,0 +1,389 @@
+function gradioApp() {
+ const elems = document.getElementsByTagName('gradio-app')
+ const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot
+ return !!gradioShadowRoot ? gradioShadowRoot : document;
+}
+
+uiUpdateCallbacks = []
+msgReceiveCallbacks = []
+
+function onUiUpdate(callback){
+ uiUpdateCallbacks.push(callback)
+}
+
+function onMsgReceive(callback){
+ msgReceiveCallbacks.push(callback)
+}
+
+function runCallback(x, m){
+ try {
+ x(m)
+ } catch (e) {
+ (console.error || console.log).call(console, e.message, e);
+ }
+}
+function executeCallbacks(queue, m) {
+ queue.forEach(function(x){runCallback(x, m)})
+}
+
+document.addEventListener("DOMContentLoaded", function() {
+ var mutationObserver = new MutationObserver(function(m){
+ executeCallbacks(uiUpdateCallbacks, m);
+ });
+ mutationObserver.observe( gradioApp(), { childList:true, subtree:true })
+});
+
+(()=>{
+ let mse_receiver_inited = null
+ onUiUpdate(()=>{
+ let app = gradioApp()
+ let msg_receiver = app.querySelector("#msg_receiver");
+ if(!!msg_receiver && mse_receiver_inited !== msg_receiver){
+ let mutationObserver = new MutationObserver(function(ms){
+ ms.forEach((m)=>{
+ m.addedNodes.forEach((node)=>{
+ if(node.nodeName === "P"){
+ let obj = JSON.parse(node.innerText);
+ if(obj instanceof Array){
+ obj.forEach((o)=>{executeCallbacks(msgReceiveCallbacks, o);});
+ }else{
+ executeCallbacks(msgReceiveCallbacks, obj);
+ }
+ }
+ })
+ })
+ });
+ mutationObserver.observe( msg_receiver, {childList:true, subtree:true, characterData:true})
+ console.log("receiver init");
+ mse_receiver_inited = msg_receiver;
+ }
+ })
+})()
+
+function HSVtoRGB(h, s, v) {
+ let r, g, b, i, f, p, q, t;
+ i = Math.floor(h * 6);
+ f = h * 6 - i;
+ p = v * (1 - s);
+ q = v * (1 - f * s);
+ t = v * (1 - (1 - f) * s);
+ switch (i % 6) {
+ case 0: r = v; g = t; b = p; break;
+ case 1: r = q; g = v; b = p; break;
+ case 2: r = p; g = v; b = t; break;
+ case 3: r = p; g = q; b = v; break;
+ case 4: r = t; g = p; b = v; break;
+ case 5: r = v; g = p; b = q; break;
+ }
+ return {
+ r: Math.round(r * 255),
+ g: Math.round(g * 255),
+ b: Math.round(b * 255)
+ };
+}
+
+class MidiVisualizer extends HTMLElement{
+ constructor() {
+ super();
+ this.midiEvents = [];
+ this.activeNotes = [];
+ this.midiTimes = [];
+ this.wrapper = null;
+ this.svg = null;
+ this.timeLine = null;
+ this.config = {
+ noteHeight : 4,
+ beatWidth: 32
+ }
+ this.timePreBeat = 16
+ this.svgWidth = 0;
+ this.t1 = 0;
+ this.playTime = 0
+ this.playTimeMs = 0
+ this.colorMap = new Map();
+ this.playing = false;
+ this.timer = null;
+ this.init();
+ }
+
+ init(){
+ this.innerHTML=''
+ const shadow = this.attachShadow({mode: 'open'});
+ const style = document.createElement("style");
+ const wrapper = document.createElement('div');
+ style.textContent = ".note.active {stroke: black;stroke-width: 0.75;stroke-opacity: 0.75;}";
+ wrapper.style.overflowX= "scroll"
+ const svg = document.createElementNS('http://www.w3.org/2000/svg', 'svg');
+ svg.style.height = `${this.config.noteHeight*128}px`;
+ svg.style.width = `${this.svgWidth}px`;
+ const timeLine = document.createElementNS('http://www.w3.org/2000/svg', 'line');
+ timeLine.style.stroke = "green"
+ timeLine.style.strokeWidth = 2;
+ shadow.appendChild(style)
+ shadow.appendChild(wrapper);
+ wrapper.appendChild(svg);
+ svg.appendChild(timeLine)
+ this.wrapper = wrapper;
+ this.svg = svg;
+ this.timeLine= timeLine;
+ this.setPlayTime(0);
+ }
+
+ clearMidiEvents(){
+ this.pause()
+ this.midiEvents = [];
+ this.activeNotes = [];
+ this.midiTimes = [];
+ this.t1 = 0
+ this.colorMap.clear()
+ this.setPlayTime(0);
+ this.playTimeMs = 0
+ this.svgWidth = 0
+ this.svg.innerHTML = ''
+ this.svg.style.width = `${this.svgWidth}px`;
+ this.svg.appendChild(this.timeLine)
+ }
+
+ appendMidiEvent(midiEvent){
+ if(midiEvent instanceof Array && midiEvent.length > 0){
+
+ this.t1 += midiEvent[1]
+ let t = this.t1*this.timePreBeat + midiEvent[2]
+ midiEvent = [midiEvent[0], t].concat(midiEvent.slice(3))
+ if(midiEvent[0] === "note"){
+ let track = midiEvent[2]
+ let duration = midiEvent[3]
+ let channel = midiEvent[4]
+ let pitch = midiEvent[5]
+ let velocity = midiEvent[6]
+ let x = (t/this.timePreBeat)*this.config.beatWidth
+ let y = (127 - pitch)*this.config.noteHeight
+ let w = (duration/this.timePreBeat)*this.config.beatWidth
+ let h = this.config.noteHeight
+ this.svgWidth = Math.ceil(Math.max(x + w, this.svgWidth))
+ let color = this.getColor(track, channel)
+ let opacity = Math.min(1, velocity/127 + 0.1).toFixed(2)
+ let rect = this.drawNote(x,y,w,h, `rgba(${color.r}, ${color.g}, ${color.b}, ${opacity})`)
+ midiEvent.push(rect)
+ this.setPlayTime(t);
+ this.wrapper.scrollTo(this.svgWidth - this.wrapper.offsetWidth, 0)
+ }
+ this.midiEvents.push(midiEvent);
+ this.svg.style.width = `${this.svgWidth}px`;
+ }
+
+ }
+
+ getColor(track, channel){
+ let key = `${track},${channel}`;
+ let color = this.colorMap.get(key);
+ if(!!color){
+ return color;
+ }
+ color = HSVtoRGB(Math.random(),Math.random()*0.5 + 0.5,1);
+ this.colorMap.set(key, color);
+ return color;
+ }
+
+ drawNote(x, y, w, h, fill) {
+ if (!this.svg) {
+ return null;
+ }
+ const rect = document.createElementNS('http://www.w3.org/2000/svg', 'rect');
+ rect.classList.add('note');
+ rect.setAttribute('fill', fill);
+ // Round values to the nearest integer to avoid partially filled pixels.
+ rect.setAttribute('x', `${Math.round(x)}`);
+ rect.setAttribute('y', `${Math.round(y)}`);
+ rect.setAttribute('width', `${Math.round(w)}`);
+ rect.setAttribute('height', `${Math.round(h)}`);
+ this.svg.appendChild(rect);
+ return rect
+ }
+
+ finishAppendMidiEvent(){
+ this.pause()
+ let midiEvents = this.midiEvents.sort((a, b)=>a[1]-b[1])
+ let tempo = (60 / 120) * 10 ** 3
+ let ms = 0
+ let lastT = 0
+ this.midiTimes.push({ms:ms, t: 0, tempo: tempo})
+ midiEvents.forEach((midiEvent)=>{
+ let t = midiEvent[1]
+ ms += ((t- lastT) / this.timePreBeat) * tempo
+ if(midiEvent[0]==="set_tempo"){
+ tempo = (60 / midiEvent[3]) * 10 ** 3
+ this.midiTimes.push({ms:ms, t: t, tempo: tempo})
+ }
+ lastT = t
+ })
+ }
+
+ setPlayTime(t){
+ this.playTime = t
+ let x = Math.round((t/this.timePreBeat)*this.config.beatWidth)
+ this.timeLine.setAttribute('x1', `${x}`);
+ this.timeLine.setAttribute('y1', '0');
+ this.timeLine.setAttribute('x2', `${x}`);
+ this.timeLine.setAttribute('y2', `${this.config.noteHeight*128}`);
+
+ this.wrapper.scrollTo(Math.max(0, x - this.wrapper.offsetWidth/2), 0)
+
+ if(this.playing){
+ let activeNotes = []
+ this.removeActiveNotes(this.activeNotes)
+ this.midiEvents.forEach((midiEvent)=>{
+ if(midiEvent[0] === "note"){
+ let time = midiEvent[1]
+ let duration = midiEvent[3]
+ let note = midiEvent[midiEvent.length - 1]
+ if(time <=this.playTime && time+duration>= this.playTime){
+ activeNotes.push(note)
+ }
+ }
+ })
+ this.addActiveNotes(activeNotes)
+ }
+ }
+
+ setPlayTimeMs(ms){
+ this.playTimeMs = ms
+ let playTime = 0
+ for(let i =0;i=ms){
+ break;
+ }
+ playTime = midiTime.t + (ms-midiTime.ms) * this.timePreBeat / midiTime.tempo
+ }
+ this.setPlayTime(playTime)
+ }
+
+ addActiveNotes(notes){
+ notes.forEach((note)=>{
+ this.activeNotes.push(note)
+ note.classList.add('active');
+ });
+ }
+
+ removeActiveNotes(notes){
+ notes.forEach((note)=>{
+ let idx = this.activeNotes.indexOf(note)
+ if(idx>-1)
+ this.activeNotes.splice(idx, 1);
+ note.classList.remove('active');
+ });
+ }
+
+ play(){
+ this.playing = true;
+ this.timer = setInterval(() => {
+ this.setPlayTimeMs(this.playTimeMs + 10)
+ }, 10);
+ }
+
+ pause(){
+ if(!!this.timer)
+ clearInterval(this.timer)
+ this.removeActiveNotes(this.activeNotes)
+ this.timer = null;
+ this.playing = false;
+ }
+
+
+ bindAudioPlayer(audio){
+ this.pause()
+ audio.addEventListener("play", (event)=>{
+ this.play()
+ })
+ audio.addEventListener("pause", (event)=>{
+ this.pause()
+ })
+ audio.addEventListener("timeupdate", (event)=>{
+ this.setPlayTimeMs(event.target.currentTime*10**3)
+ })
+ }
+}
+
+customElements.define('midi-visualizer', MidiVisualizer);
+
+(()=>{
+ let midi_visualizer_container_inited = null
+ let midi_audio_inited = null;
+ let midi_visualizer = document.createElement('midi-visualizer')
+ onUiUpdate((m)=>{
+ let app = gradioApp()
+ let midi_visualizer_container = app.querySelector("#midi_visualizer_container");
+ if(!!midi_visualizer_container && midi_visualizer_container_inited!== midi_visualizer_container){
+ midi_visualizer_container.appendChild(midi_visualizer)
+ midi_visualizer_container_inited = midi_visualizer_container;
+ }
+ let midi_audio = app.querySelector("#midi_audio > audio");
+ if(!!midi_audio && midi_audio_inited!==midi_audio){
+ midi_visualizer.bindAudioPlayer(midi_audio)
+ midi_audio_inited = midi_audio
+ }
+ })
+
+ function createProgressBar(progressbarContainer){
+ let parentProgressbar = progressbarContainer.parentNode;
+ let divProgress = document.createElement('div');
+ divProgress.className='progressDiv';
+ let rect = progressbarContainer.getBoundingClientRect();
+ divProgress.style.width = rect.width + "px";
+ divProgress.style.background = "#b4c0cc";
+ divProgress.style.borderRadius = "8px";
+ let divInner = document.createElement('div');
+ divInner.className='progress';
+ divInner.style.color = "white";
+ divInner.style.background = "#0060df";
+ divInner.style.textAlign = "right";
+ divInner.style.fontWeight = "bold";
+ divInner.style.borderRadius = "8px";
+ divInner.style.height = "20px";
+ divInner.style.lineHeight = "20px";
+ divInner.style.paddingRight = "8px"
+ divInner.style.width = "0%";
+ divProgress.appendChild(divInner);
+ parentProgressbar.insertBefore(divProgress, progressbarContainer);
+ }
+
+ function removeProgressBar(progressbarContainer){
+ let parentProgressbar = progressbarContainer.parentNode;
+ let divProgress = parentProgressbar.querySelector(".progressDiv");
+ parentProgressbar.removeChild(divProgress);
+ }
+
+ function setProgressBar(progressbarContainer, progress, total){
+ let parentProgressbar = progressbarContainer.parentNode;
+ let divProgress = parentProgressbar.querySelector(".progressDiv");
+ let divInner = parentProgressbar.querySelector(".progress");
+ if(total===0)
+ total = 1;
+ divInner.style.width = `${(progress/total)*100}%`;
+ divInner.textContent = `${progress}/${total}`;
+ }
+
+ onMsgReceive((msg)=>{
+ switch (msg.name) {
+ case "visualizer_clear":
+ midi_visualizer.clearMidiEvents();
+ createProgressBar(midi_visualizer_container_inited)
+ break;
+ case "visualizer_append":
+ midi_visualizer.appendMidiEvent(msg.data);
+ break;
+ case "progress":
+ let progress = msg.data[0]
+ let total = msg.data[1]
+ setProgressBar(midi_visualizer_container_inited, progress, total)
+ break;
+ case "visualizer_end":
+ midi_visualizer.finishAppendMidiEvent()
+ midi_visualizer.setPlayTime(0);
+ removeProgressBar(midi_visualizer_container_inited);
+ break;
+ default:
+ }
+ })
+})();
diff --git a/midi_tokenizer.py b/midi_tokenizer.py
index 1335baf..ab287e5 100644
--- a/midi_tokenizer.py
+++ b/midi_tokenizer.py
@@ -101,6 +101,19 @@ def event2tokens(self, event):
tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
return tokens
+ def tokens2event(self, tokens):
+ if tokens[0] in self.id_events:
+ name = self.id_events[tokens[0]]
+ if len(tokens) <= len(self.events[name]):
+ return []
+ params = tokens[1:]
+ params = [params[i] - self.parameter_ids[p][0] for i, p in enumerate(self.events[name])]
+ if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
+ return []
+ event = [name] + params
+ return event
+ return []
+
def detokenize(self, midi_seq):
ticks_per_beat = 480
tracks_dict = {}