Skip to content

Commit

Permalink
Merge branch 'main' into matteo/multiplart-v3
Browse files Browse the repository at this point in the history
  • Loading branch information
chamini2 authored Nov 29, 2024
2 parents 9a761d4 + d9d552c commit 3bf65c2
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 3 deletions.
130 changes: 129 additions & 1 deletion projects/fal/src/fal/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Iterator
from typing import TYPE_CHECKING, Any, Iterator

import httpx

from fal import flags
from fal.sdk import Credentials, get_default_credentials

if TYPE_CHECKING:
from websockets.sync.connection import Connection

# URL templates for reaching a deployed app. Each f-string is evaluated once at
# import time with the host taken from `flags.FAL_RUN_HOST`; the doubled braces
# leave a literal `{app_id}` placeholder that is filled in later through
# `str.format(app_id=...)`.
_QUEUE_URL_FORMAT = f"https://queue.{flags.FAL_RUN_HOST}/{{app_id}}"
_REALTIME_URL_FORMAT = f"wss://{flags.FAL_RUN_HOST}/{{app_id}}"
# Websocket endpoint on the `ws.` subdomain; `ws()` formats it with the app id.
_WS_URL_FORMAT = f"wss://ws.{flags.FAL_RUN_HOST}/{{app_id}}"


def _backwards_compatible_app_id(app_id: str) -> str:
Expand Down Expand Up @@ -245,3 +249,127 @@ def _connect(app_id: str, *, path: str = "/realtime") -> Iterator[_RealtimeConne
url, additional_headers=creds.to_headers(), open_timeout=90
) as ws:
yield _RealtimeConnection(ws)


class _MetaMessageFound(Exception):
    """Internal signal: a meta message was found where response data was read.

    Raising it inside the ``_WSConnection._recv`` context manager leaves the
    buffered message unconsumed so that the next read sees it again.
    """


@dataclass
class _WSConnection:
    """A WS connection to an HTTP Fal app.

    Every response is read as a ``start`` meta message, followed by the
    response itself (a single text message, or one or more binary messages
    that together form a msgpack document), followed by an ``end`` meta
    message carrying the same ``request_id`` as the ``start`` one.

    A meta message is a text message containing a JSON object that has both
    a ``type`` and a ``request_id`` key.
    """

    _ws: Connection
    # One-message look-ahead: a message that has been read from the socket but
    # not consumed yet. None means the next read must hit the socket.
    _buffer: str | bytes | None = None

    def run(self, arguments: dict[str, Any]) -> Any:
        """Send ``arguments`` and return the next response.

        The return value is whatever ``recv`` produces: the raw text for a
        text response (not JSON-decoded), or the msgpack-decoded object for a
        binary response.
        """
        self.send(arguments)
        return self.recv()

    def send(self, arguments: dict[str, Any]) -> None:
        """JSON-encode ``arguments`` and send them as a single message."""
        import json

        self._ws.send(json.dumps(arguments))

    def _peek(self) -> bytes | str:
        """Return the next message without consuming it."""
        if self._buffer is None:
            self._buffer = self._ws.recv()

        return self._buffer

    def _consume(self) -> None:
        """Drop the buffered message.

        Raises:
            ValueError: if there is no buffered message.
        """
        if self._buffer is None:
            raise ValueError("No data to consume")

        self._buffer = None

    @contextmanager
    def _recv(self) -> Iterator[str | bytes]:
        """Yield the next message, consuming it only on a clean exit."""
        res = self._peek()

        yield res

        # Only consume if it went through the context manager without raising
        self._consume()

    @staticmethod
    def _parse_meta(res: str | bytes) -> dict[str, Any] | None:
        """Return the decoded meta message, or None if ``res`` is not one."""
        if not isinstance(res, str):
            return None

        try:
            payload: Any = json.loads(res)
        except json.JSONDecodeError:
            return None

        if not isinstance(payload, dict):
            return None

        if "type" in payload and "request_id" in payload:
            return payload
        return None

    def _is_meta(self, res: str | bytes) -> bool:
        """Tell whether ``res`` is a meta message."""
        return self._parse_meta(res) is not None

    def _recv_meta(self, type: str) -> dict[str, Any]:
        """Read the next message, which must be a meta message of ``type``.

        Raises:
            ValueError: if the next message is not a meta message of the
                expected type. The message stays buffered in that case.
        """
        with self._recv() as res:
            # Decode once and reuse the result for both checks
            meta = self._parse_meta(res)
            if meta is None or meta["type"] != type:
                raise ValueError(f"Expected a {type} message")

            return meta

    def _recv_response(self) -> Any:
        """Read the response body that sits between two meta messages.

        Returns:
            The text as-is for a text message, otherwise the msgpack-decoded
            concatenation of all binary messages up to the next meta message.

        Raises:
            ValueError: if a meta message arrives before any body data.
        """
        chunks: list[bytes] = []
        while True:
            try:
                with self._recv() as res:
                    if self._is_meta(res):
                        # Keep the meta message for later
                        raise _MetaMessageFound()

                    if isinstance(res, str):
                        # A text message is the whole response on its own
                        return res

                    chunks.append(res)
            except _MetaMessageFound:
                break

        # Join once instead of growing a bytes object chunk by chunk
        body = b"".join(chunks)
        if not body:
            raise ValueError("Empty response body")

        # Imported only when there is binary data to decode
        import msgpack

        return msgpack.unpackb(body)

    def recv(self) -> Any:
        """Read one complete response (start meta, body, end meta).

        Raises:
            ValueError: if the messages do not follow that sequence, the body
                is empty, or the ``end`` message has a different
                ``request_id`` than the ``start`` message.
        """
        start = self._recv_meta("start")
        request_id = start["request_id"]

        response = self._recv_response()

        end = self._recv_meta("end")
        if end["request_id"] != request_id:
            raise ValueError("Mismatched request_id in end message")

        return response


@contextmanager
def ws(app_id: str, *, path: str = "") -> Iterator[_WSConnection]:
    """Open a websocket connection to an app's HTTP endpoint.

    This is an internal and experimental API, use it at your own risk.

    Args:
        app_id: The app to connect to.
        path: Optional endpoint path appended to the app URL; a single
            leading "/" is accepted and not duplicated.

    Yields:
        A ``_WSConnection`` wrapping the open websocket. The websocket is
        closed when the ``with`` block exits.
    """

    from websockets.sync import client

    target = _WS_URL_FORMAT.format(app_id=_backwards_compatible_app_id(app_id))
    if path:
        # Exactly one "/" separates the app URL from the endpoint path
        suffix = path[1:] if path[0] == "/" else path
        target = f"{target}/{suffix}"

    credentials = get_default_credentials()

    with client.connect(
        target, additional_headers=credentials.to_headers(), open_timeout=90
    ) as socket:
        yield _WSConnection(socket)
43 changes: 43 additions & 0 deletions projects/fal/src/fal/cli/machine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from .parser import FalClientParser


def _kill(args):
    """Handle ``machine kill``: ask the host to kill the given runner.

    Reads the host from ``args.host`` and the runner ID from ``args.id``.
    """
    from fal.sdk import FalServerlessClient

    with FalServerlessClient(args.host).connect() as conn:
        conn.kill_runner(args.id)


def _add_kill_parser(subparsers, parents):
    """Register the ``kill`` subcommand, which takes a runner ID."""
    summary = "Kill a machine."
    kill_parser = subparsers.add_parser(
        "kill", parents=parents, help=summary, description=summary
    )
    kill_parser.add_argument("id", help="Runner ID.")
    # The selected subcommand's handler is stored as `func` on the parsed args
    kill_parser.set_defaults(func=_kill)


def add_parser(main_subparsers, parents):
    """Register the ``machine`` command and its subcommands."""
    summary = "Manage fal machines."
    machine_parser = main_subparsers.add_parser(
        "machine", parents=parents, help=summary, description=summary
    )

    # A subcommand is mandatory; `machine` does nothing on its own
    commands = machine_parser.add_subparsers(
        parser_class=FalClientParser,
        required=True,
        metavar="command",
        title="Commands",
    )

    _add_kill_parser(commands, parents)
4 changes: 2 additions & 2 deletions projects/fal/src/fal/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fal.console import console
from fal.console.icons import CROSS_ICON

from . import apps, auth, create, deploy, doctor, keys, run, secrets
from . import apps, auth, create, deploy, doctor, keys, machine, run, secrets
from .debug import debugtools, get_debug_parser
from .parser import FalParser, FalParserExit

Expand All @@ -31,7 +31,7 @@ def _get_main_parser() -> argparse.ArgumentParser:
required=True,
)

for cmd in [auth, apps, deploy, run, keys, secrets, doctor, create]:
for cmd in [auth, apps, deploy, run, keys, secrets, doctor, create, machine]:
cmd.add_parser(subparsers, parents)

return parser
Expand Down
4 changes: 4 additions & 0 deletions projects/fal/src/fal/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,3 +686,7 @@ def list_secrets(self) -> list[ServerlessSecret]:
)
for secret in response.secrets
]

def kill_runner(self, runner_id: str) -> None:
    """Send a ``KillRunner`` request for ``runner_id`` through the stub."""
    self.stub.KillRunner(isolate_proto.KillRunnerRequest(runner_id=runner_id))
71 changes: 71 additions & 0 deletions projects/fal/tests/test_apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,23 @@ def decrement(self, input: StatefulInput) -> Output:
return Output(result=self.counter)


# Request model for SleepApp's endpoint.
class SleepInput(BaseModel):
    # Handed to asyncio.sleep() by the endpoint, i.e. a duration in seconds
    wait_time: int


# Response model for SleepApp's endpoint; it carries no fields.
class SleepOutput(BaseModel):
    pass


# App whose only endpoint sleeps for the requested time. test_kill_runner uses
# it to keep a request in progress while the runner is listed and killed.
class SleepApp(fal.App, keep_alive=300, max_concurrency=1):
    machine_type = "XS"

    @fal.endpoint("/")
    async def sleep(self, input: SleepInput) -> SleepOutput:
        # Wait `input.wait_time` seconds without blocking the event loop
        await asyncio.sleep(input.wait_time)
        return SleepOutput()


class ExceptionApp(fal.App, keep_alive=300, max_concurrency=1):
machine_type = "XS"

Expand Down Expand Up @@ -378,6 +395,21 @@ def test_app_client(test_app: str, test_nomad_app: str):
assert response["result"] == 5


def test_ws_client(test_app: str):
    """Check the websocket client with lock-step and pipelined requests."""
    with apps.ws(test_app) as conn:
        # Lock-step: every request waits for its own response
        for rhs in range(3):
            payload = json.loads(conn.run({"lhs": 1, "rhs": rhs}))
            assert payload["result"] == 1 + rhs

        # Pipelined: send all requests before reading anything back
        for rhs in range(3):
            conn.send({"lhs": 2, "rhs": rhs})

        for rhs in range(3):
            # they should be in order
            payload = json.loads(conn.recv())
            assert payload["result"] == 2 + rhs


def test_app_client_old_format(test_app: str):
assert test_app.count("/") == 1, "Test app should be in new format"
old_format = test_app.replace("/", "-")
Expand Down Expand Up @@ -772,3 +804,42 @@ def test_app_exceptions(test_exception_app: AppClient):

assert cuda_exc.value.status_code == _CUDA_OOM_STATUS_CODE
assert _CUDA_OOM_MESSAGE in cuda_exc.value.message


def test_kill_runner():
    """Register a sleeping app, start a request, then kill its runner.

    Also checks that, if killing an unknown runner ID raises, the error
    mentions "not found".
    """
    import uuid

    app_alias = f"{uuid.uuid4()}-sleep-alias"
    wrapped = fal.wrap_app(SleepApp)
    revision = wrapped.host.register(
        func=wrapped.func,
        options=wrapped.options,
        application_name=app_alias,
        application_auth_mode="private",
    )

    host: api.FalServerlessHost = wrapped.host  # type: ignore

    owner = _get_user()

    handle = apps.submit(f"{owner.user_id}/{revision}", arguments={"wait_time": 10})

    # Poll once per second until the request leaves the queue and starts running
    status = handle.status()
    while not isinstance(status, apps.InProgress):
        if not isinstance(status, apps.Queued):
            raise Exception(f"Failed to start the app: {status}")
        time.sleep(1)
        status = handle.status()

    with host._connection as client:
        try:
            client.kill_runner("1234567890")
        except Exception as err:
            assert "not found" in str(err).lower()

        # The in-progress request should be served by exactly one runner
        alias_runners = client.list_alias_runners(app_alias)
        assert len(alias_runners) == 1

        client.kill_runner(alias_runners[0].runner_id)

0 comments on commit 3bf65c2

Please sign in to comment.