Skip to content

Commit

Permalink
Formatting.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Sep 15, 2024
1 parent 148c5bc commit 8e7d926
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 40 deletions.
5 changes: 4 additions & 1 deletion moshi/moshi/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ async def _recv_loop(self) -> None:
continue
message = message.data
if not isinstance(message, bytes):
self.printer.log("warning", f"unsupported message type {type(message)}")
self.printer.log(
"warning", f"unsupported message type {type(message)}"
)
continue
if len(message) == 0:
self.printer.log("warning", "empty message")
Expand Down Expand Up @@ -137,6 +139,7 @@ async def run(self) -> None:
self._recv_loop(), self._decoder_loop(), self._queue_loop()
)


async def run(printer: AnyPrinter, args):
uri = f"ws://{args.host}:{args.port}/api/chat"
async with aiohttp.ClientSession() as session:
Expand Down
7 changes: 6 additions & 1 deletion moshi/moshi/models/encodec.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
from torch.nn import functional as F


from ..quantization import QuantizedResult, BaseQuantizer, SplitResidualVectorQuantizer, ResidualVectorQuantizer
from ..quantization import (
QuantizedResult,
BaseQuantizer,
SplitResidualVectorQuantizer,
ResidualVectorQuantizer,
)
from ..modules.resample import ConvDownsample1d, ConvTrUpsample1d
from ..modules.streaming import StreamingModule, State
from ..utils.compile import no_compile, CUDAGraphed
Expand Down
16 changes: 11 additions & 5 deletions moshi/moshi/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def warmup(self):
async def handle_chat(self, request):
ws = web.WebSocketResponse()
await ws.prepare(request)

async def recv_loop():
nonlocal close
try:
Expand Down Expand Up @@ -203,7 +204,7 @@ async def send_loop():
self.ec.reset_streaming()
self.lm_gen.reset_streaming()
# Send the handshake.
await ws.send_bytes(b'\x00')
await ws.send_bytes(b"\x00")
await asyncio.gather(opus_loop(), recv_loop(), send_loop())
log("info", "done with connection")
return ws
Expand All @@ -214,7 +215,7 @@ def main():
log("info", "warming up the model")
state.warmup()
app = web.Application()
app.router.add_get('/api/chat', state.handle_chat)
app.router.add_get("/api/chat", state.handle_chat)
static_path: None | str = None
if args.static is None:
log("info", f"retrieving the static content")
Expand All @@ -223,13 +224,18 @@ def main():
# When set to the "none" string, we don't serve any static content.
static_path = args.static
if static_path is not None:

async def handle_root(_):
return web.FileResponse(os.path.join(static_path, 'index.html'))
return web.FileResponse(os.path.join(static_path, "index.html"))

log("info", f"serving static content from {static_path}")
app.router.add_get('/', handle_root)
app.router.add_static('/', path=static_path, follow_symlinks=True, name='static')
app.router.add_get("/", handle_root)
app.router.add_static(
"/", path=static_path, follow_symlinks=True, name="static"
)
log("info", f"listening to ws://{args.host}:{args.port}")
web.run_app(app, port=args.port)


with torch.no_grad():
main()
44 changes: 21 additions & 23 deletions moshi_mlx/moshi_mlx/client_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass
import sys


def colorize(text, color):
code = f"\033[{color}m"
restore = "\033[0m"
Expand All @@ -20,7 +21,7 @@ def make_log(level: str, msg: str) -> str:
prefix = colorize("Error:", "1;31")
else:
raise ValueError(f"Unknown level {level}")
return prefix + ' ' + msg
return prefix + " " + msg


class RawPrinter:
Expand All @@ -39,7 +40,7 @@ def log(self, level: str, msg: str):
print(f"{level.capitalize()}: {msg}", file=self.err_stream)

def print_lag(self):
self.err_stream.write(colorize(' [LAG]', '31'))
self.err_stream.write(colorize(" [LAG]", "31"))
self.err_stream.flush()

def print_pending(self):
Expand Down Expand Up @@ -92,7 +93,7 @@ def erase(self, count: int = 1):
else:
entries = list(self._line)
self._line.clear()
self.stream.write('\r')
self.stream.write("\r")
for entry in entries:
self._line.append(entry)
self.stream.write(entry.render())
Expand All @@ -102,16 +103,16 @@ def erase(self, count: int = 1):
def newline(self):
missing = self._max_line_length - len(self)
if missing > 0:
self.stream.write(' ' * missing)
self.stream.write('\n')
self.stream.write(" " * missing)
self.stream.write("\n")
self._line.clear()
self._max_line_length = 0
self._has_padding = False

def flush(self):
missing = self._max_line_length - len(self)
if missing > 0:
self.stream.write(' ' * missing)
self.stream.write(" " * missing)
self._has_padding = True
self.stream.flush()

Expand All @@ -126,10 +127,10 @@ def __init__(self, max_cols: int = 80, stream=sys.stdout, err_stream=sys.stderr)
self._pending_printed = False

def print_header(self):
self.line.add(' ' + '-' * (self.max_cols) + ' ')
self.line.add(" " + "-" * (self.max_cols) + " ")
self.line.newline()
self.line.flush()
self.line.add('| ')
self.line.add("| ")

def _remove_pending(self) -> bool:
if self._pending_printed:
Expand All @@ -144,41 +145,41 @@ def print_token(self, token: str, color: str | None = None):
if len(token) <= remaining:
self.line.add(token, color)
else:
end = ' ' * remaining + ' |'
if token.startswith(' '):
end = " " * remaining + " |"
if token.startswith(" "):
token = token.lstrip()
self.line.add(end)
self.line.newline()
self.line.add('| ')
self.line.add("| ")
self.line.add(token, color)
else:
assert color is None
erase_count = None
cumulated = ''
cumulated = ""
for idx, entry in enumerate(self.line._line[::-1]):
if entry.color:
# probably a LAG message
erase_count = idx
break
if entry.msg.startswith(' '):
if entry.msg.startswith(" "):
erase_count = idx + 1
cumulated = entry.msg + cumulated
break
if erase_count is not None:
if erase_count > 0:
self.line.erase(erase_count)
remaining = self.max_cols - len(self.line)
end = ' ' * remaining + ' |'
end = " " * remaining + " |"
self.line.add(end)
self.line.newline()
self.line.add('| ')
self.line.add("| ")
token = cumulated.lstrip() + token
self.line.add(token)
else:
self.line.add(token[:remaining])
self.line.add(' |')
self.line.add(" |")
self.line.newline()
self.line.add('| ')
self.line.add("| ")
self.line.add(token[remaining:])
self.line.flush()

Expand All @@ -192,20 +193,17 @@ def log(self, level: str, msg: str):
self.err_stream.flush()

def print_lag(self):
self.print_token(' [LAG]', '31')
self.print_token(" [LAG]", "31")

def print_pending(self):
chars = ['|', '/', '-', '\\']
chars = ["|", "/", "-", "\\"]
count = int(self._pending_count / 5)
char = chars[count % len(chars)]
colors = ['32', '33', '31']
colors = ["32", "33", "31"]
self._remove_pending()
self.line.add(char, colors[count % len(colors)])
self._pending_printed = True
self._pending_count += 1


AnyPrinter = Printer | RawPrinter



1 change: 1 addition & 0 deletions moshi_mlx/moshi_mlx/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def hf_hub_download(repo, path: str) -> str:
raise ValueError(f"the --hf-repo flag is required to retrieve {path}")
return huggingface_hub.hf_hub_download(repo, path)


class Stats:
send_times: tp.List[float] = []
model_times: tp.List[tp.Tuple[float, float]] = []
Expand Down
28 changes: 18 additions & 10 deletions moshi_mlx/moshi_mlx/local_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@
FRAME_SIZE = 1920
CHANNELS = 1


def colorize(text, color):
    """Wrap ``text`` in ANSI escape sequences so a terminal shows it styled.

    Args:
        text: The string to decorate.
        color: The parameter string inserted between ``ESC[`` and ``m``,
            for example ``"1;31"``.

    Returns:
        ``text`` prefixed with the ``ESC[<color>m`` sequence and followed by
        the ``ESC[0m`` reset sequence, so the styling does not leak into
        whatever is printed afterwards.
    """
    code = f"\033[{color}m"
    restore = "\033[0m"
    return "".join([code, text, restore])


def log(level: str, msg: str):
if level == "warning":
prefix = colorize("[Warn]", "1;31")
Expand All @@ -55,6 +57,7 @@ def hf_hub_download(repo, path: str) -> str:
raise ValueError(f"the --hf-repo flag is required to retrieve {path}")
return huggingface_hub.hf_hub_download(repo, path)


class Stats:
send_times: tp.List[float] = []
model_times: tp.List[tp.Tuple[float, float]] = []
Expand Down Expand Up @@ -216,9 +219,11 @@ async def recv_loop2():
continue

lock = asyncio.Lock()

async def handle_chat(request):
ws = web.WebSocketResponse()
await ws.prepare(request)

async def recv_loop():
nonlocal close
try:
Expand Down Expand Up @@ -263,8 +268,8 @@ async def opus_loop():
else:
all_pcm_data = np.concatenate((all_pcm_data, pcm))
while all_pcm_data.shape[-1] >= FRAME_SIZE:
chunk = all_pcm_data[: FRAME_SIZE]
all_pcm_data = all_pcm_data[FRAME_SIZE :]
chunk = all_pcm_data[:FRAME_SIZE]
all_pcm_data = all_pcm_data[FRAME_SIZE:]
input_queue.put_nowait(chunk)

async def send_loop():
Expand Down Expand Up @@ -299,39 +304,42 @@ async def another_loop():
opus_writer = sphn.OpusStreamWriter(SAMPLE_RATE)
opus_reader = sphn.OpusStreamReader(SAMPLE_RATE)
# Send the handshake.
await ws.send_bytes(b'\x00')
await ws.send_bytes(b"\x00")
await asyncio.gather(opus_loop(), recv_loop(), send_loop(), another_loop())
log("info", "done with connection")
return ws


async def go():
app = web.Application()
app.router.add_get('/api/chat', handle_chat)
app.router.add_get("/api/chat", handle_chat)
static_path: None | str = None
if args.static is None:
log("info", f"retrieving the static content")
dist_tgz = hf_hub_download(args.hf_repo, "dist.tgz")
dist_tgz = Path(dist_tgz)
dist = dist_tgz.parent / "dist"
if not dist.exists():
with tarfile.open(dist_tgz, 'r:gz') as tar:
with tarfile.open(dist_tgz, "r:gz") as tar:
tar.extractall(path=dist_tgz.parent)
static_path = str(dist)
elif args.static != "none":
# When set to the "none" string, we don't serve any static content.
static_path = args.static
if static_path is not None:

async def handle_root(_):
return web.FileResponse(os.path.join(static_path, 'index.html'))
return web.FileResponse(os.path.join(static_path, "index.html"))

log("info", f"serving static content from {static_path}")
app.router.add_get('/', handle_root)
app.router.add_static('/', path=static_path, name='static')
app.router.add_get("/", handle_root)
app.router.add_static("/", path=static_path, name="static")
log("info", f"listening to ws://{args.host}:{args.port}")
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, args.host, args.port)
await asyncio.gather(recv_loop(), send_loop(), recv_loop2(), send_loop2(), site.start())
await asyncio.gather(
recv_loop(), send_loop(), recv_loop2(), send_loop2(), site.start()
)
await runner.cleanup()

try:
Expand Down

0 comments on commit 8e7d926

Please sign in to comment.