From a9a57885845c7cd0dc3e5fc5ded156268552d985 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 22 Aug 2024 14:59:11 -0400 Subject: [PATCH 01/20] [docs] Add a tip on L1 bandwidth (#142) --- docs/amdgpu_kernel_optimization_guide.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/amdgpu_kernel_optimization_guide.md b/docs/amdgpu_kernel_optimization_guide.md index bf597cd94..09c5b59f9 100644 --- a/docs/amdgpu_kernel_optimization_guide.md +++ b/docs/amdgpu_kernel_optimization_guide.md @@ -4,7 +4,7 @@ Author: Jakub Kuderski @kuhar Date: 2024-06-24 -Last Update: 2024-08-14 +Last Update: 2024-08-22 ## Introduction @@ -280,6 +280,11 @@ at once. A sequence of up to 4 adjacent `global_load_dwordx4` instructions (implicitly) forms a *clause* that translates to a single data fabric transaction. +> [!TIP] +> To achieve peak L1 bandwidth, make sure that your memory access engages all +> four L1 cache sets. That is, at the level of the workgroup, you should be +> loading 4 cache lines (128 B) that each map to a different cache set. + > [!TIP] > For data that is 'streamed' and does not need to be cached, consider > using *non-temporal* loads/stores. This disables coherency and invalidates From 224393acf40ed2f1ce321d8fadea8cc25eacd3a5 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 22 Aug 2024 16:27:15 -0700 Subject: [PATCH 02/20] [shortfin] Factor out shortfin.interop.fastapi.FastAPIResponder and add test. --- .../shortfin/interop/fastapi/__init__.py | 113 +++++++++++ libshortfin/examples/python/fastapi/server.py | 133 +++++++++++++ .../examples/python/http/http_server.py | 180 ------------------ .../async_test.py} | 2 +- libshortfin/tests/examples/fastapi_test.py | 109 +++++++++++ 5 files changed, 356 insertions(+), 181 deletions(-) create mode 100644 libshortfin/bindings/python/shortfin/interop/fastapi/__init__.py create mode 100644 libshortfin/examples/python/fastapi/server.py delete mode 100644 libshortfin/examples/python/http/http_server.py rename libshortfin/tests/{examples_test.py => examples/async_test.py} (93%) create mode 100644 libshortfin/tests/examples/fastapi_test.py diff --git a/libshortfin/bindings/python/shortfin/interop/fastapi/__init__.py b/libshortfin/bindings/python/shortfin/interop/fastapi/__init__.py new file mode 100644 index 000000000..2cff38342 --- /dev/null +++ b/libshortfin/bindings/python/shortfin/interop/fastapi/__init__.py @@ -0,0 +1,113 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import asyncio + +try: + from fastapi import Request, Response + from fastapi.responses import StreamingResponse +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "Shortfin fastapi interop requires fastapi to be installed" + ) from e + + +class FastAPIResponder: + """Bridge between FastAPI and shortfin that can be used to send out of band + responses back to a waiting FastAPI async request. + + This isn't really shortfin specific and can be used to bridge to any non + webserver owned loop. + + It is typically used by putting it in a Message that is sent to some processing + queue. Then return/awaiting it from an API callback. Example: + + ``` + @app.get("/predict") + async def predict(value: int, request: Request): + message = RequestMessage(value, FastAPIResponder(request)) + system.request_writer(message) + return await message.responder.response + ``` + + See: examples/python/fastapi/server.py + """ + + def __init__(self, request: Request): + super().__init__() + self.request = request + # Capture the running loop so that we can send responses back. + self._loop = asyncio.get_running_loop() + self.response = asyncio.Future(loop=self._loop) + self._responded = False + self._streaming_queue: asyncio.Queue | None = None + self.is_disconnected = False + + def close_with_error(self): + # Called in a failsafe fashion as part of exception handlers seeking to + # shutdown the response. If not yet responded, this will response with + # a status code of 500. If streaming, then None will be streamed. + if self._responded: + if self._streaming_queue: + self.stream_part(None) + else: + self.send_response(Response(status_code=500)) + + def send_response(self, response: Response): + """Sends a response back for this transaction. + + This is intended for sending single part responses back. See + start_response() for sending back a streaming, multi-part response. + """ + assert not self._responded, "Response already sent" + if self._loop.is_closed(): + raise IOError("Web server is shut down") + self._responded = True + self._loop.call_soon_threadsafe(self.response.set_result, response) + + def start_response(self, **kwargs): + """Starts a streaming response, passing the given kwargs to the + fastapi.responses.StreamingResponse constructor. + + This is appropriate to use for generating a sparse response stream as is + typical of chat apps. As it will hop threads for each part, other means should + be used for bulk transfer (i.e. by scheduling on the webserver loop + directly). + """ + assert not self._responded, "Response already sent" + if self._loop.is_closed(): + raise IOError("Web server is shut down") + self._responded = True + self._streaming_queue = asyncio.Queue() + + async def gen(request, streaming_queue): + while True: + if await request.is_disconnected(): + self.is_disconnected = True + part = await streaming_queue.get() + if part is None: + break + yield part + + def start(request, streaming_queue, response_future): + response = StreamingResponse(gen(request, streaming_queue), **kwargs) + response_future.set_result(response) + + self._loop.call_soon_threadsafe( + start, self.request, self._streaming_queue, self.response + ) + + def stream_part(self, content: bytes | None): + """Streams content to a response started with start_response(). + + Streaming must be ended by sending None. + """ + assert self._streaming_queue is not None, "start_response() not called" + if self._loop.is_closed(): + raise IOError("Web server is shut down") + self._loop.call_soon_threadsafe(self._streaming_queue.put_nowait, content) + if content is None: + self._streaming_queue = None diff --git a/libshortfin/examples/python/fastapi/server.py b/libshortfin/examples/python/fastapi/server.py new file mode 100644 index 000000000..66ab37b75 --- /dev/null +++ b/libshortfin/examples/python/fastapi/server.py @@ -0,0 +1,133 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse +import asyncio +import traceback +from contextlib import asynccontextmanager +import json +import threading +import sys + +from fastapi import FastAPI, Request, Response +from fastapi.responses import JSONResponse +import shortfin as sf +from shortfin.interop.fastapi import FastAPIResponder +import uvicorn + + +class RequestMessage(sf.Message): + def __init__(self, request_value: int, responder: FastAPIResponder): + super().__init__() + self.request_value = request_value + self.responder = responder + + +class System: + def __init__(self): + self.ls = sf.host.CPUSystemBuilder().create_system() + # TODO: Come up with an easier bootstrap thing than manually + # running a thread. + self.t = threading.Thread(target=lambda: self.ls.run(self.run())) + self.request_queue = self.ls.create_queue("request") + self.request_writer = self.request_queue.writer() + + def start(self): + self.t.start() + + def shutdown(self): + self.request_queue.close() + + async def run(self): + print("*** Sytem Running ***") + request_reader = self.request_queue.reader() + while request := await request_reader(): + try: + responder = request.responder + if request.request_value == 0: + raise ValueError("Something broke") + elif request.request_value > 20: + responder.send_response(Response(status_code=400)) + elif request.request_value == 1: + # Send a single response. + responder.send_response( + JSONResponse({"answer": request.request_value}) + ) + else: + # Stream responses from 0..value + responder.start_response() + for i in range(request.request_value + 1): + if responder.is_disconnected: + continue + responder.stream_part( + (json.dumps({"answer": i}) + "\n\0").encode() + ) + await asyncio.sleep(0.01) + responder.stream_part(None) + except Exception as e: + responder.close_with_error() + traceback.print_exc() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + system.start() + yield + print("Shutting down shortfin") + system.shutdown() + + +system = System() +app = FastAPI(lifespan=lifespan) + + +@app.get("/predict") +async def predict(value: int, request: Request): + message = RequestMessage(value, FastAPIResponder(request)) + system.request_writer(message) + return await message.responder.response + + +@app.get("/health") +async def health() -> Response: + return Response(status_code=200) + + +def main(argv): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--root-path", + type=str, + default=None, + help="Root path to use for installing behind path based proxy.", + ) + parser.add_argument( + "--timeout-keep-alive", type=int, default=5, help="Keep alive timeout" + ) + parser.add_argument( + "--testing-mock-service", + action="store_true", + help="Enable the mock testing service", + ) + parser.add_argument( + "--device-uri", type=str, default="local-task", help="Device URI to serve on" + ) + + args = parser.parse_args(argv) + + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="debug", + timeout_keep_alive=args.timeout_keep_alive, + ) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/libshortfin/examples/python/http/http_server.py b/libshortfin/examples/python/http/http_server.py deleted file mode 100644 index 43b62f06d..000000000 --- a/libshortfin/examples/python/http/http_server.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import argparse -import asyncio -from contextlib import asynccontextmanager -import threading -import sys - -from fastapi import FastAPI, Request, Response -from fastapi.responses import JSONResponse, StreamingResponse -import shortfin as sf -import uvicorn - - -class FastAPIResponder(sf.Message): - """Bridge between FastAPI and shortfin that can be put on a queue and used to - send a response back at an arbitrary point. - - This object is constructed in a FastAPI handler, capturing the current event loop - used by the web server. Then it can be put on a shortfin Queue and once within - a shortfin worker, an arbitrary worker can call `send_response` to send a simple - FastAPI response back to the webserver loop and onto the client. - - """ - - def __init__(self, request: Request): - super().__init__() - self.request = request - # Capture the running loop so that we can send responses back. - self._loop = asyncio.get_running_loop() - self.response = asyncio.Future(loop=self._loop) - self._responded = False - self._streaming_queue: asyncio.Queue | None = None - self.is_disconnected = False - - def send_response(self, response: Response): - """Sends a response back for this transaction. - - This is intended for sending single part responses back. See - start_response() for sending back a streaming, multi-part response. - """ - assert not self._responded, "Response already sent" - if self._loop.is_closed(): - raise IOError("Web server is shut down") - self._responded = True - self._loop.call_soon_threadsafe(self.response.set_result, response) - - def start_response(self, **kwargs): - """Starts a streaming response, passing the given kwargs to the - fastapi.responses.StreamingResponse constructor. - - This is appropriate to use for generating a sparse response stream as is - typical of chat apps. As it will hop threads for each part, other means should - be used for bulk transfer (i.e. by scheduling on the webserver loop - directly). - """ - assert not self._responded, "Response already sent" - if self._loop.is_closed(): - raise IOError("Web server is shut down") - self._responded = True - self._streaming_queue = asyncio.Queue() - - async def gen(): - while True: - if await self.request.is_disconnected(): - self.is_disconnected = True - part = await self._streaming_queue.get() - if part is None: - break - yield part - - def start(): - response = StreamingResponse(gen(), **kwargs) - self.response.set_result(response) - - self._loop.call_soon_threadsafe(start) - - def stream_part(self, content: bytes | None): - """Streams content to a response started with start_response(). - - Streaming must be ended by sending None. - """ - assert self._streaming_queue is not None, "start_response() not called" - if self._loop.is_closed(): - raise IOError("Web server is shut down") - self._loop.call_soon_threadsafe(self._streaming_queue.put_nowait, content) - - -class System: - def __init__(self): - self.ls = sf.host.CPUSystemBuilder().create_system() - # TODO: Come up with an easier bootstrap thing than manually - # running a thread. - self.t = threading.Thread(target=lambda: self.ls.run(self.run())) - self.request_queue = self.ls.create_queue("request") - self.request_writer = self.request_queue.writer() - - def start(self): - self.t.start() - - def shutdown(self): - self.request_queue.close() - - async def run(self): - print("*** Sytem Running ***") - request_reader = self.request_queue.reader() - while responder := await request_reader(): - print("Got request:", responder) - # Can send a single response: - # request.send_response(JSONResponse({"answer": 42})) - # Or stream: - responder.start_response() - for i in range(20): - if responder.is_disconnected: - print("Cancelled!") - break - responder.stream_part(f"Iteration {i}\n".encode()) - await asyncio.sleep(0.2) - else: - responder.stream_part(None) - - -@asynccontextmanager -async def lifespan(app: FastAPI): - system.start() - yield - print("Shutting down shortfin") - system.shutdown() - - -system = System() -app = FastAPI(lifespan=lifespan) - - -@app.get("/predict") -async def predict(request: Request): - transaction = FastAPIResponder(request) - system.request_writer(transaction) - return await transaction.response - - -def main(argv): - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default=None) - parser.add_argument("--port", type=int, default=8000) - parser.add_argument( - "--root-path", - type=str, - default=None, - help="Root path to use for installing behind path based proxy.", - ) - parser.add_argument( - "--timeout-keep-alive", type=int, default=5, help="Keep alive timeout" - ) - parser.add_argument( - "--testing-mock-service", - action="store_true", - help="Enable the mock testing service", - ) - parser.add_argument( - "--device-uri", type=str, default="local-task", help="Device URI to serve on" - ) - - args = parser.parse_args(argv) - - uvicorn.run( - app, - host=args.host, - port=args.port, - log_level="debug", - timeout_keep_alive=args.timeout_keep_alive, - ) - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/libshortfin/tests/examples_test.py b/libshortfin/tests/examples/async_test.py similarity index 93% rename from libshortfin/tests/examples_test.py rename to libshortfin/tests/examples/async_test.py index 54b815cc4..1595d7d8e 100644 --- a/libshortfin/tests/examples_test.py +++ b/libshortfin/tests/examples/async_test.py @@ -11,7 +11,7 @@ import subprocess import sys -project_dir = Path(__file__).resolve().parent.parent +project_dir = Path(__file__).resolve().parent.parent.parent example_dir = project_dir / "examples" / "python" diff --git a/libshortfin/tests/examples/fastapi_test.py b/libshortfin/tests/examples/fastapi_test.py new file mode 100644 index 000000000..f19c1c12f --- /dev/null +++ b/libshortfin/tests/examples/fastapi_test.py @@ -0,0 +1,109 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from contextlib import closing +import os +from pathlib import Path +import pytest +import requests +import socket +import subprocess +import sys +import time + +project_dir = Path(__file__).resolve().parent.parent.parent +example_dir = project_dir / "examples" / "python" + + +@pytest.fixture(scope="session") +def server(): + runner = ServerRunner([]) + yield runner + print("Sending kill signal") + runner.process.terminate() + print("Waiting for server to exit") + runner.process.wait(20) + + +# Test error first to make sure it doesn't mess up the server. +def test_error_response(server): + resp = requests.get(f"{server.url}/predict?value=0") + assert resp.status_code == 500 + + +def test_single_response(server): + resp = requests.get(f"{server.url}/predict?value=1") + resp.raise_for_status() + full_contents = resp.content + print(full_contents) + assert full_contents == b'{"answer":1}' + + +def test_stream_response(server): + resp = requests.get(f"{server.url}/predict?value=20") + resp.raise_for_status() + full_contents = resp.content + print(full_contents) + exp_contents = ("".join(['{"answer": %s}\n\x00' % i for i in range(21)])).encode() + assert full_contents == exp_contents + + +class ServerRunner: + def __init__(self, args): + port = str(find_free_port()) + self.url = "http://localhost:" + port + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + self.process = subprocess.Popen( + [ + sys.executable, + str(example_dir / "fastapi" / "server.py"), + "--port=" + port, + ] + + args, + env=env, + # TODO: Have a more robust way of forking a subprocess. + cwd=str(example_dir), + stdout=sys.stdout, + stderr=sys.stderr, + ) + self._wait_for_ready() + + def _wait_for_ready(self): + start = time.time() + while True: + try: + if requests.get(f"{self.url}/health").status_code == 200: + return + except Exception as e: + if self.process.poll() is not None: + raise RuntimeError("API server processs terminated") from e + time.sleep(1.0) + if (time.time() - start) > 30: + raise RuntimeError("Timeout waiting for server start") + + def __del__(self): + try: + process = self.process + except AttributeError: + pass + else: + process.terminate() + process.wait() + + +def find_free_port(): + """This tries to find a free port to run a server on for the test. + + Race conditions are possible - the port can be acquired between when this + runs and when the server starts. + + https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number + """ + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("localhost", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] From 388bbfac0b17a5bca394cbe6133bf5db1f2236c2 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 22 Aug 2024 19:54:57 -0700 Subject: [PATCH 03/20] [shortfin] Run down some bugs discovered with debug/asan built CPython. * The coroutine for a PyProcess was not being properly waited, resulting in two issues: 1. Exceptions were getting swallowed, and 2. If the GC hits right/wrong and the debug infra notices that something goes unawaited, it will notify the global loop exception handler, which will try to report it. Since this often happens at process shutdown, it is a toss up as to what the outcome will be when trying to log a failure, since many things will be in a partially destructed state. * array::storage was not retaining the backing device. Since arrays have a lifetime independent from the system, this could result in them outliving the device, causing crashes/corruption. The way the back reference is kept makes this safe but is still not ideal. TODO left. * Fixes worker PyThreadState cleanup sequence. It was triggering an assertion in Python debug builds. * Makes some changes to the mobilenet demo that were used to debug the system. --- libshortfin/CMakeLists.txt | 10 +++ .../python/_shortfin/asyncio_bridge.py | 3 - libshortfin/bindings/python/lib_ext.cc | 73 ++++++++++--------- .../{server.py => inference_system.py} | 57 +++++++++++++-- libshortfin/src/shortfin/array/storage.h | 12 ++- libshortfin/src/shortfin/local/process.cc | 5 +- libshortfin/src/shortfin/local/worker.h | 2 + 7 files changed, 112 insertions(+), 50 deletions(-) rename libshortfin/examples/python/mobilenet_server/{server.py => inference_system.py} (53%) diff --git a/libshortfin/CMakeLists.txt b/libshortfin/CMakeLists.txt index 20571005c..8794a2229 100644 --- a/libshortfin/CMakeLists.txt +++ b/libshortfin/CMakeLists.txt @@ -35,6 +35,16 @@ option(SHORTFIN_BUILD_TESTS "Builds C++ tests" ON) option(SHORTFIN_BUNDLE_DEPS "Download dependencies instead of using system libraries" OFF) set(SHORTFIN_IREE_SOURCE_DIR "" CACHE FILEPATH "Path to IREE source") +# Enabling ASAN. Note that this will work best if building in a completely +# bundled fashion and with an ASAN rigged CPython. Otherwise, various LD_PRELOAD +# hacks are needed. This is merely a develope convenience: people are more +# than welcome to set flags themselves. +option(SHORTFIN_ENABLE_ASAN "Enable ASAN" OFF) +if(SHORTFIN_ENABLE_ASAN) + add_compile_options(-fsanitize=address) + add_link_options(-fsanitize=address) +endif() + include(FetchContent) # Includes. diff --git a/libshortfin/bindings/python/_shortfin/asyncio_bridge.py b/libshortfin/bindings/python/_shortfin/asyncio_bridge.py index 0ef214527..63ded30e9 100644 --- a/libshortfin/bindings/python/_shortfin/asyncio_bridge.py +++ b/libshortfin/bindings/python/_shortfin/asyncio_bridge.py @@ -5,9 +5,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import asyncio -from collections.abc import Callable -from contextvars import Context -from typing_extensions import Unpack from . import lib as sfl diff --git a/libshortfin/bindings/python/lib_ext.cc b/libshortfin/bindings/python/lib_ext.cc index 15070afaa..58d20532a 100644 --- a/libshortfin/bindings/python/lib_ext.cc +++ b/libshortfin/bindings/python/lib_ext.cc @@ -94,22 +94,19 @@ class PyWorkerExtension : public local::Worker::Extension { py::gil_scoped_acquire g; loop_.reset(); - // Scrub thread state if not donated. - if (worker().options().owned_thread) { - PyThreadState_Clear(PyThreadState_Get()); - } else { - // Otherwise, juse reset the event loop. - refs_->asyncio_set_event_loop(py::none()); - refs_->asyncio_set_running_loop(py::none()); - } + // reset the event loop. + refs_->asyncio_set_event_loop(py::none()); + refs_->asyncio_set_running_loop(py::none()); } // And destroy our thread state (if not donated). - // TODO: PyThreadState_Delete seems like it should be used here, but I - // couldn't find that being done and I couldn't find a way to use it - // with the GIL/thread state correct. if (worker().options().owned_thread) { - PyThreadState_Swap(nullptr); + // Ordinarily PyGILState_Ensure must be balanced with PyGILState_Release, + // by PyThreadState_DeleteCurrent() implicitly releases it as part of + // its cleanup process. + PyGILState_STATE gil_state = PyGILState_Ensure(); + PyThreadState_Clear(PyThreadState_Get()); + PyThreadState_DeleteCurrent(); } } @@ -150,29 +147,33 @@ class PyProcess : public local::detail::BaseProcess { std::bind(&PyProcess::RunOnWorker, self_object)); } static void RunOnWorker(py::handle self_handle) { - { - py::gil_scoped_acquire g; - // Steal the reference back from ScheduleOnWorker. Important: this is - // very likely the last reference to the process. So self must not be - // touched after self_object goes out of scope. - py::object self_object = py::steal(self_handle); - PyProcess *self = py::cast(self_handle); - // We assume that the run method either returns None (def) or a coroutine - // (async def). - auto coro = self_object.attr("run")(); - if (!coro.is_none()) { - auto task = self->refs_->asyncio_create_task(coro); - // Capture the self object to avoid lifetime hazzard with PyProcess - // going away before done. - task.attr("add_done_callback")( - py::cpp_function([self_object](py::handle future) { - PyProcess *done_self = py::cast(self_object); - done_self->Terminate(); - })); - } else { - // Synchronous termination. - self->Terminate(); - } + py::gil_scoped_acquire g; + // Steal the reference back from ScheduleOnWorker. Important: this is + // very likely the last reference to the process. So self must not be + // touched after self_object goes out of scope. + py::object self_object = py::steal(self_handle); + PyProcess *self = py::cast(self_handle); + // We assume that the run method either returns None (def) or a coroutine + // (async def). + auto coro = self_object.attr("run")(); + if (!coro.is_none()) { + auto task = self->refs_->asyncio_create_task(coro); + // Capture the self object to avoid lifetime hazzard with PyProcess + // going away before done. + task.attr("add_done_callback")( + py::cpp_function([self_object](py::handle future) { + PyProcess *done_self = py::cast(self_object); + done_self->Terminate(); + // The result of the process future doesn't matter to us, but it + // may be carrying an exception and this is our only chance to + // bubble it. If it is, this will throw and be handled by the + // last chance exception handler in the worker. + // TODO: Route process termination and exceptions to a supervisor. + future.attr("result")(); + })); + } else { + // Synchronous termination. + self->Terminate(); } } @@ -341,6 +342,8 @@ void BindLocal(py::module_ &m) { return self.CreateWorker(options); }, py::arg("name"), py::rv_policy::reference_internal) + .def_prop_ro("init_worker", &local::System::init_worker, + py::rv_policy::reference_internal) .def( "run", [refs](local::System &self, py::object coro) { diff --git a/libshortfin/examples/python/mobilenet_server/server.py b/libshortfin/examples/python/mobilenet_server/inference_system.py similarity index 53% rename from libshortfin/examples/python/mobilenet_server/server.py rename to libshortfin/examples/python/mobilenet_server/inference_system.py index c8f6484bf..e2be35910 100644 --- a/libshortfin/examples/python/mobilenet_server/server.py +++ b/libshortfin/examples/python/mobilenet_server/inference_system.py @@ -7,35 +7,59 @@ import asyncio from pathlib import Path +import sys import shortfin as sf +import shortfin.array as sfnp + +MAX_BATCH = 8 + + +class InferenceRequest(sf.Message): + def __init__(self, raw_image_data): + super().__init__() + self.raw_image_data = raw_image_data class InferenceProcess(sf.Process): - def __init__(self, program, **kwargs): + def __init__(self, program, request_queue, **kwargs): super().__init__(**kwargs) self.program = program + self.request_reader = request_queue.reader() + self.device = self.scope.device(0) + self.host_staging = sfnp.host_array( + self.device, [MAX_BATCH, 3, 224, 224], sfnp.float32 + ) + self.device_input = sfnp.device_array( + self.device, [MAX_BATCH, 3, 224, 224], sfnp.float32 + ) async def run(self): print(f"Inference process: {self.pid}") + while request := await self.request_reader(): + print(f"[{self.pid}] Got request {request}") + # self.host_staging.data = self.raw_image_data class Main: def __init__(self, lsys: sf.System, home_dir: Path): - self.processes_per_worker = 4 + self.processes_per_worker = 1 self.lsys = lsys self.home_dir = home_dir + self.request_queue = lsys.create_queue("request") self.program_module = self.lsys.load_module(home_dir / "model.vmfb") print(f"Loaded: {self.program_module}") self.processes = [] - async def initialize(self, scope): + async def start_scope(self, scope): # Note that currently, program load is synchronous. But we do it # in a task so we can await it in the future and let program loads # overlap. program = scope.load_unbound_program([self.program_module]) for _ in range(self.processes_per_worker): - self.processes.append(InferenceProcess(program, scope=scope).launch()) + self.processes.append( + InferenceProcess(program, self.request_queue, scope=scope).launch() + ) async def main(self): devices = self.lsys.devices @@ -50,21 +74,40 @@ async def main(self): for device in devices: worker = self.lsys.create_worker(f"device-{device.name}") scope = self.lsys.create_scope(worker, devices=[device]) - initializers.append(self.initialize(scope)) + initializers.append(self.start_scope(scope)) # Run all initializers in parallel. These launch inference processes. + print("Waiting for initializers") await asyncio.gather(*initializers) # Wait for inference processes to end. + print(f"Running {len(self.processes)} inference processes") await asyncio.gather(*self.processes) + print("Inference processors completed") + + +def run_cli(home_dir: Path, argv): + def client(): + # Create a random image. + print("Preparing requests...") + writer = main.request_queue.writer() + + # Dumb way to prepare some data to feed [1, 3, 224, 224] f32. + import array + + dummy_data = array.array("f", [0.2] * (3 * 224 * 224)) + message = InferenceRequest(dummy_data) + writer(message) + # Done. + writer.close() -def run_server(home_dir: Path): lsys = sf.host.CPUSystemBuilder().create_system() main = Main(lsys, home_dir) + lsys.init_worker.call_threadsafe(client) lsys.run(main.main()) if __name__ == "__main__": home_dir = Path(__file__).resolve().parent - run_server(home_dir) + run_cli(home_dir, sys.argv[1:]) diff --git a/libshortfin/src/shortfin/array/storage.h b/libshortfin/src/shortfin/array/storage.h index 10d0313fe..644f865ac 100644 --- a/libshortfin/src/shortfin/array/storage.h +++ b/libshortfin/src/shortfin/array/storage.h @@ -58,9 +58,19 @@ class SHORTFIN_API storage { private: storage(local::ScopedDevice device, iree::hal_buffer_ptr buffer, local::detail::TimelineResource::Ref timeline_resource) - : buffer_(std::move(buffer)), + : hal_device_ownership_baton_(iree::hal_device_ptr::borrow_reference( + device.raw_device()->hal_device())), + buffer_(std::move(buffer)), device_(device), timeline_resource_(std::move(timeline_resource)) {} + // TODO(ownership): Since storage is a free-standing object in the system, + // it needs an ownership baton that keeps the device/driver alive. Otherwise, + // it can outlive the backing device and then then crashes on buffer + // deallocation. For now, we stash an RAII hal_device_ptr, which keeps + // everything alive. This isn't quite what we want but keeps us going for now. + // When fixing, add a test that creates an array, destroys the System, and + // then frees the array. + iree::hal_device_ptr hal_device_ownership_baton_; iree::hal_buffer_ptr buffer_; local::ScopedDevice device_; local::detail::TimelineResource::Ref timeline_resource_; diff --git a/libshortfin/src/shortfin/local/process.cc b/libshortfin/src/shortfin/local/process.cc index 4fd395368..b40b8ce87 100644 --- a/libshortfin/src/shortfin/local/process.cc +++ b/libshortfin/src/shortfin/local/process.cc @@ -54,10 +54,7 @@ void detail::BaseProcess::Launch() { ScheduleOnWorker(); } -void detail::BaseProcess::ScheduleOnWorker() { - logging::info("ScheduleOnWorker()"); - Terminate(); -} +void detail::BaseProcess::ScheduleOnWorker() { Terminate(); } void detail::BaseProcess::Terminate() { int deallocate_pid; diff --git a/libshortfin/src/shortfin/local/worker.h b/libshortfin/src/shortfin/local/worker.h index 585e92d90..52f5e5948 100644 --- a/libshortfin/src/shortfin/local/worker.h +++ b/libshortfin/src/shortfin/local/worker.h @@ -73,6 +73,8 @@ class SHORTFIN_API Worker { Worker(Options options); Worker(const Worker &) = delete; + Worker &operator=(const Worker &) = delete; + Worker(Worker &&) = delete; ~Worker(); const Options &options() const { return options_; } From b4317071ddab4615822acd44e10f8953d99893e3 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 23 Aug 2024 11:36:10 +0200 Subject: [PATCH 04/20] [libshortfin] Introduce `SHORTFIN_SYSTEMS_AMDGPU` (#137) Introduces an `SHORTFIN_SYSTEMS_AMDGPU` option to allow building libshortfin for host-only systems. The option defaults to ON. --- .../workflows/ci_linux_x64-libshortfin.yml | 26 ++++++++++++++----- libshortfin/CMakeLists.txt | 13 +++++++++- libshortfin/bindings/python/lib_ext.cc | 6 +++++ libshortfin/src/CMakeLists.txt | 7 ++++- 4 files changed, 44 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml index 27d63ae71..289a83c67 100644 --- a/.github/workflows/ci_linux_x64-libshortfin.yml +++ b/.github/workflows/ci_linux_x64-libshortfin.yml @@ -22,7 +22,7 @@ permissions: env: IREE_REPO_DIR: ${{ github.workspace }}/iree - BUILD_DIR: ${{ github.workspace }}/libshortfin/build + LIBSHORTFIN_DIR: ${{ github.workspace }}/libshortfin/ jobs: build-and-test: @@ -86,10 +86,10 @@ jobs: # TODO: Switch to `pip install -r requirements.txt -e libshortfin/`. run: pip install nanobind typing_extensions - - name: Build libshortfin + - name: Build libshortfin (full) run: | - mkdir ${{ env.BUILD_DIR }} - cd ${{ env.BUILD_DIR }} + mkdir ${{ env.LIBSHORTFIN_DIR }}/build + cd ${{ env.LIBSHORTFIN_DIR }}/build cmake -GNinja \ -DCMAKE_C_COMPILER=clang-18 \ -DCMAKE_CXX_COMPILER=clang++-18 \ @@ -99,7 +99,21 @@ jobs: .. cmake --build . --target all - - name: Test libshortfin + - name: Test libshortfin (full) run: | - cd ${{ env.BUILD_DIR }} + cd ${{ env.LIBSHORTFIN_DIR }}/build cmake --build . --target test + + - name: Build libshortfin (host-only) + run: | + mkdir ${{ env.LIBSHORTFIN_DIR }}/build-host-only + cd ${{ env.LIBSHORTFIN_DIR }}/build-host-only + cmake -GNinja \ + -DCMAKE_C_COMPILER=clang-18 \ + -DCMAKE_CXX_COMPILER=clang++-18 \ + -DCMAKE_LINKER_TYPE=LLD \ + -DCMAKE_PREFIX_PATH=${{ env.IREE_REPO_DIR }}/build/lib/cmake/IREE \ + -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \ + -DSHORTFIN_HAVE_AMDGPU=OFF \ + .. + cmake --build . --target all diff --git a/libshortfin/CMakeLists.txt b/libshortfin/CMakeLists.txt index 8794a2229..0187b1671 100644 --- a/libshortfin/CMakeLists.txt +++ b/libshortfin/CMakeLists.txt @@ -33,6 +33,7 @@ endif() option(SHORTFIN_BUILD_PYTHON_BINDINGS "Builds Python Bindings" OFF) option(SHORTFIN_BUILD_TESTS "Builds C++ tests" ON) option(SHORTFIN_BUNDLE_DEPS "Download dependencies instead of using system libraries" OFF) + set(SHORTFIN_IREE_SOURCE_DIR "" CACHE FILEPATH "Path to IREE source") # Enabling ASAN. Note that this will work best if building in a completely @@ -45,6 +46,14 @@ if(SHORTFIN_ENABLE_ASAN) add_link_options(-fsanitize=address) endif() +option(SHORTFIN_SYSTEMS_AMDGPU "Builds for AMD GPU systems" ON) +message(STATUS "libshortfin supported systems:") +if(SHORTFIN_SYSTEMS_AMDGPU) + message(STATUS " - AMD GPU") + add_compile_definitions("SHORTFIN_HAVE_AMDGPU") +endif() +message(STATUS " - Host") + include(FetchContent) # Includes. @@ -120,7 +129,9 @@ if(SHORTFIN_IREE_SOURCE_DIR) set(IREE_HAL_DRIVER_DEFAULTS OFF) set(IREE_HAL_DRIVER_LOCAL_SYNC ON) set(IREE_HAL_DRIVER_LOCAL_TASK ON) - set(IREE_HAL_DRIVER_HIP ON) + if(SHORTFIN_SYSTEMS_AMDGPU) + set(IREE_HAL_DRIVER_HIP ON) + endif() add_subdirectory(${SHORTFIN_IREE_SOURCE_DIR} shortfin_iree SYSTEM EXCLUDE_FROM_ALL) else() # Try to find iree using find_package diff --git a/libshortfin/bindings/python/lib_ext.cc b/libshortfin/bindings/python/lib_ext.cc index 58d20532a..6072caa04 100644 --- a/libshortfin/bindings/python/lib_ext.cc +++ b/libshortfin/bindings/python/lib_ext.cc @@ -13,7 +13,9 @@ #include "shortfin/local/program.h" #include "shortfin/local/scope.h" #include "shortfin/local/system.h" +#if defined(SHORTFIN_HAVE_AMDGPU) #include "shortfin/local/systems/amdgpu.h" +#endif // SHORTFIN_HAVE_AMDGPU #include "shortfin/local/systems/host.h" #include "shortfin/support/globals.h" #include "shortfin/support/logging.h" @@ -239,7 +241,9 @@ NB_MODULE(lib, m) { auto local_m = m.def_submodule("local"); BindLocal(local_m); BindHostSystem(local_m); +#if defined(SHORTFIN_HAVE_AMDGPU) BindAMDGPUSystem(local_m); +#endif // SHORTFIN_HAVE_AMDGPU auto array_m = m.def_submodule("array"); BindArray(array_m); @@ -712,6 +716,7 @@ void BindHostSystem(py::module_ &global_m) { py::class_(m, "HostCPUDevice"); } +#if defined(SHORTFIN_HAVE_AMDGPU) void BindAMDGPUSystem(py::module_ &global_m) { auto m = global_m.def_submodule("amdgpu", "AMDGPU system config"); py::class_(m, "AMDGPUDevice"); } +#endif // SHORTFIN_HAVE_AMDGPU } // namespace shortfin::python diff --git a/libshortfin/src/CMakeLists.txt b/libshortfin/src/CMakeLists.txt index 1a69094e0..0d6d7ace7 100644 --- a/libshortfin/src/CMakeLists.txt +++ b/libshortfin/src/CMakeLists.txt @@ -13,6 +13,11 @@ target_include_directories( $) +set(_INIT_INTERNAL_DEPS) +if(SHORTFIN_SYSTEMS_AMDGPU) + list(APPEND _INIT_INTERNAL_DEPS shortfin_systems_amdgpu) +endif() + shortfin_public_library( NAME shortfin @@ -20,6 +25,6 @@ shortfin_public_library( shortfin_array shortfin_local shortfin_support - shortfin_systems_amdgpu shortfin_systems_host + ${_INIT_INTERNAL_DEPS} ) From 925140c5bba27dcf78028b004d3d21fe67a177ad Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Fri, 23 Aug 2024 11:52:19 +0100 Subject: [PATCH 05/20] [shortfin] Teach setup.py to build using cmake (#139) --- libshortfin/setup.py | 174 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 155 insertions(+), 19 deletions(-) diff --git a/libshortfin/setup.py b/libshortfin/setup.py index f3fc4e9b9..4f4074d2f 100644 --- a/libshortfin/setup.py +++ b/libshortfin/setup.py @@ -5,9 +5,14 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from distutils.core import setup, Extension +import sys +import shutil +import subprocess import os from pathlib import Path +from distutils.command.build import build as _build from setuptools.command.build_ext import build_ext as _build_ext +from setuptools.command.build_py import build_py as _build_py # This file can be generated into the build directory to allow an arbitrary @@ -18,35 +23,45 @@ CPP_PREBUILT_SOURCE_DIR = "@libshortfin_SOURCE_DIR@" CPP_PREBUILT_BINARY_DIR = "@libshortfin_BINARY_DIR@" +SETUPPY_DIR = os.path.realpath(os.path.dirname(__file__)) + def is_cpp_prebuilt(): return CPP_PREBUILT == "TRUE" -def native_build(): - if is_cpp_prebuilt(): - print("setup.py running in pre-built mode from:") - print(f" SOURCE_DIR = {CPP_PREBUILT_SOURCE_DIR}") - print(f" BINARY_DIR = {CPP_PREBUILT_BINARY_DIR}") - return Path(CPP_PREBUILT_SOURCE_DIR), Path(CPP_PREBUILT_BINARY_DIR) - raise RuntimeError("Packaging currently only supported in pre-built mode") - +if is_cpp_prebuilt(): + print("setup.py running in pre-built mode:", file=sys.stderr) + SOURCE_DIR = Path(CPP_PREBUILT_SOURCE_DIR) + BINARY_DIR = Path(CPP_PREBUILT_BINARY_DIR) +else: + print("setup.py running in cmake build mode:", file=sys.stderr) + # setup.py is in the source directory. + SOURCE_DIR = Path(SETUPPY_DIR) + BINARY_DIR = Path(os.path.join(SETUPPY_DIR, "build", "b")) -source_dir, binary_dir = native_build() +print(f" SOURCE_DIR = {SOURCE_DIR}", file=sys.stderr) +print(f" BINARY_DIR = {BINARY_DIR}", file=sys.stderr) # Due to a quirk of setuptools, that package_dir map must only contain # paths relative to the directory containing setup.py. Why? No one knows. -current_dir = Path(__file__).resolve().parent -rel_source_dir = source_dir.relative_to(current_dir, walk_up=True) -rel_binary_dir = binary_dir.relative_to(current_dir, walk_up=True) +REL_SOURCE_DIR = SOURCE_DIR.relative_to(SETUPPY_DIR, walk_up=True) +REL_BINARY_DIR = BINARY_DIR.relative_to(SETUPPY_DIR, walk_up=True) -class BuiltExtension(Extension): +class CMakeExtension(Extension): def __init__(self, name, sourcedir=""): Extension.__init__(self, name, sources=[]) self.sourcedir = os.path.abspath(sourcedir) +class CustomBuild(_build): + def run(self): + self.run_command("build_py") + self.run_command("build_ext") + self.run_command("build_scripts") + + class NoopBuildExtension(_build_ext): def build_extension(self, ext): ... @@ -55,8 +70,127 @@ def copy_extensions_to_source(self, *args, **kwargs): ... -python_src_dir = rel_source_dir / "bindings" / "python" -python_bin_dir = rel_binary_dir / "bindings" / "python" +def maybe_nuke_cmake_cache(cmake_build_dir): + # From run to run under pip, we can end up with different paths to ninja, + # which isn't great and will confuse cmake. Detect if the location of + # ninja changes and force a cache flush. + ninja_path = "" + try: + import ninja + except ModuleNotFoundError: + pass + else: + ninja_path = ninja.__file__ + expected_stamp_contents = f"{sys.executable}\n{ninja_path}" + + # In order to speed things up on CI and not rebuild everything, we nuke + # the CMakeCache.txt file if the path to the Python interpreter changed. + # Ideally, CMake would let us reconfigure this dynamically... but it does + # not (and gets very confused). + PYTHON_STAMP_FILE = os.path.join(cmake_build_dir, "python_stamp.txt") + if os.path.exists(PYTHON_STAMP_FILE): + with open(PYTHON_STAMP_FILE, "rt") as f: + actual_stamp_contents = f.read() + if actual_stamp_contents == expected_stamp_contents: + # All good. + return + + # Mismatch or not found. Clean it. + cmake_cache_file = os.path.join(cmake_build_dir, "CMakeCache.txt") + if os.path.exists(cmake_cache_file): + print("Removing CMakeCache.txt because Python version changed", file=sys.stderr) + os.remove(cmake_cache_file) + + # And write. + with open(PYTHON_STAMP_FILE, "wt") as f: + f.write(expected_stamp_contents) + + +class CMakeBuildPy(_build_py): + def run(self): + # The super-class handles the pure python build. + super().run() + + # Build using cmake if not in prebuild mode. + if not is_cpp_prebuilt(): + + # Build extension using cmake. + print("*****************************", file=sys.stderr) + print("* Building libshortfin *", file=sys.stderr) + print("*****************************", file=sys.stderr) + + cfg = os.getenv("SHORTFIN_CMAKE_BUILD_TYPE", "Release") + + CMAKE_BUILD_DIR = BINARY_DIR + + # Configure CMake. + os.makedirs(BINARY_DIR, exist_ok=True) + maybe_nuke_cmake_cache(CMAKE_BUILD_DIR) + print(f"CMake build dir: {CMAKE_BUILD_DIR}", file=sys.stderr) + cmake_args = [ + "-GNinja", + "--log-level=VERBOSE", + "-DSHORTFIN_BUNDLE_DEPS=ON", + f"-DCMAKE_BUILD_TYPE={cfg}", + "-DSHORTFIN_BUILD_PYTHON_BINDINGS=ON", + # TODO: This shouldn't be hardcoded... but shortfin doesn't + # compile without it. + "-DCMAKE_C_COMPILER=clang", + "-DCMAKE_CXX_COMPILER=clang++", + ] + + # Only do a from-scratch configure if not already configured. + cmake_cache_file = os.path.join(CMAKE_BUILD_DIR, "CMakeCache.txt") + if not os.path.exists(cmake_cache_file): + print(f"Configuring with: {cmake_args}", file=sys.stderr) + subprocess.check_call( + ["cmake", SOURCE_DIR] + cmake_args, cwd=CMAKE_BUILD_DIR + ) + else: + print(f"Not re-configing (already configured)", file=sys.stderr) + + # Build. + subprocess.check_call(["cmake", "--build", "."], cwd=CMAKE_BUILD_DIR) + print("Build complete.", file=sys.stderr) + + # We only take _shortfin_default from the build. + target_dir = os.path.join( + os.path.abspath(self.build_lib), "_shortfin_default" + ) + print(f"Building in target: {target_dir}", file=sys.stderr) + os.makedirs(target_dir, exist_ok=True) + print("Copying build to target.", file=sys.stderr) + if os.path.exists(target_dir): + shutil.rmtree(target_dir) + shutil.copytree( + os.path.join( + CMAKE_BUILD_DIR, + "bindings", + "python", + "_shortfin_default", + ), + target_dir, + symlinks=False, + ) + + +PYTHON_SOURCE_DIR = REL_SOURCE_DIR / "bindings" / "python" +PYTHON_BINARY_DIR = REL_BINARY_DIR / "bindings" / "python" + +# We need some directories to exist before setup. +def populate_built_package(abs_dir): + """Makes sure that a directory and __init__.py exist. + + This needs to unfortunately happen before any of the build process + takes place so that setuptools can plan what needs to be built. + We do this for any built packages (vs pure source packages). + """ + os.makedirs(abs_dir, exist_ok=True) + with open(os.path.join(abs_dir, "__init__.py"), "wt"): + pass + + +populate_built_package(os.path.join(PYTHON_BINARY_DIR / "_shortfin_default")) setup( name="shortfin", @@ -71,16 +205,18 @@ def copy_extensions_to_source(self, *args, **kwargs): ], zip_safe=False, package_dir={ - "_shortfin": str(python_src_dir / "_shortfin"), - "_shortfin_default": str(python_bin_dir / "_shortfin_default"), + "_shortfin": str(PYTHON_SOURCE_DIR / "_shortfin"), + "_shortfin_default": str(PYTHON_BINARY_DIR / "_shortfin_default"), # TODO: Conditionally map additional native library variants. - "shortfin": str(python_src_dir / "shortfin"), + "shortfin": str(PYTHON_SOURCE_DIR / "shortfin"), }, ext_modules=[ - BuiltExtension("_shortfin_default.lib"), + CMakeExtension("_shortfin_default.lib") # TODO: Conditionally map additional native library variants. ], cmdclass={ + "build": CustomBuild, "build_ext": NoopBuildExtension, + "build_py": CMakeBuildPy, }, ) From 889cada3c0aabd2d02bf0f21ed115020dff82884 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 23 Aug 2024 15:52:13 +0200 Subject: [PATCH 06/20] Execute Python tests in CI (#146) With the Python tests are executed in the CI via pytest. Tests that are not capable to run due to missing hardware are not executed whereas failing test that are expected to pass in the future are currently marked xfail. --- .github/workflows/ci_linux_x64-libshortfin.yml | 8 +++++--- libshortfin/tests/amdgpu_system_test.py | 3 +++ libshortfin/tests/examples/fastapi_test.py | 3 +++ 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml index 289a83c67..796beb9e6 100644 --- a/.github/workflows/ci_linux_x64-libshortfin.yml +++ b/.github/workflows/ci_linux_x64-libshortfin.yml @@ -1,4 +1,3 @@ -#!/bin/bash # Copyright 2024 Advanced Micro Devices, Inc # # Licensed under the Apache License v2.0 with LLVM Exceptions. @@ -80,11 +79,11 @@ jobs: - name: Setup Python uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1 with: - python-version: "3.11" + python-version: "3.12" cache: "pip" - name: Install Python packages # TODO: Switch to `pip install -r requirements.txt -e libshortfin/`. - run: pip install nanobind typing_extensions + run: pip install nanobind pytest requests - name: Build libshortfin (full) run: | @@ -98,11 +97,14 @@ jobs: -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \ .. cmake --build . --target all + pip install -v -e . - name: Test libshortfin (full) run: | cd ${{ env.LIBSHORTFIN_DIR }}/build cmake --build . --target test + cd ${{ env.LIBSHORTFIN_DIR }} + pytest -m "not requires_amd_gpu" - name: Build libshortfin (host-only) run: | diff --git a/libshortfin/tests/amdgpu_system_test.py b/libshortfin/tests/amdgpu_system_test.py index 74ea69af2..4c6d1fae0 100644 --- a/libshortfin/tests/amdgpu_system_test.py +++ b/libshortfin/tests/amdgpu_system_test.py @@ -4,7 +4,10 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import pytest + +@pytest.mark.requires_amd_gpu def test_create_host_cpu_system(): from _shortfin import lib as sfl diff --git a/libshortfin/tests/examples/fastapi_test.py b/libshortfin/tests/examples/fastapi_test.py index f19c1c12f..f34cf7e24 100644 --- a/libshortfin/tests/examples/fastapi_test.py +++ b/libshortfin/tests/examples/fastapi_test.py @@ -29,11 +29,13 @@ def server(): # Test error first to make sure it doesn't mess up the server. +@pytest.mark.xfail(raises=RuntimeError, reason="Failing (but should work)") def test_error_response(server): resp = requests.get(f"{server.url}/predict?value=0") assert resp.status_code == 500 +@pytest.mark.xfail(raises=RuntimeError, reason="Failing (but should work)") def test_single_response(server): resp = requests.get(f"{server.url}/predict?value=1") resp.raise_for_status() @@ -42,6 +44,7 @@ def test_single_response(server): assert full_contents == b'{"answer":1}' +@pytest.mark.xfail(raises=RuntimeError, reason="Failing (but should work)") def test_stream_response(server): resp = requests.get(f"{server.url}/predict?value=20") resp.raise_for_status() From cbf739c078eb3cc981d8bcc63a3137fd55d06385 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 23 Aug 2024 15:54:08 +0200 Subject: [PATCH 07/20] [libshortfin] Shallow clone git repositories (#147) Only get shallow copies of all repositories cloned as bundeled dep or in the CI. For `actions/checkout` the fetch depth defaults to `1`. --- .github/workflows/ci_linux_x64-libshortfin.yml | 11 +++++------ libshortfin/CMakeLists.txt | 5 +++++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml index 796beb9e6..34b385f6e 100644 --- a/.github/workflows/ci_linux_x64-libshortfin.yml +++ b/.github/workflows/ci_linux_x64-libshortfin.yml @@ -46,16 +46,15 @@ jobs: repository: iree-org/iree path: ${{ env.IREE_REPO_DIR }} submodules: false - depth: 1 - name: Initalize IREE submodules run : | cd ${{ env.IREE_REPO_DIR }} - git submodule update --init -- third_party/benchmark - git submodule update --init -- third_party/cpuinfo/ - git submodule update --init -- third_party/flatcc - git submodule update --init -- third_party/googletest - git submodule update --init -- third_party/hip-build-deps/ + git submodule update --init --depth 1 -- third_party/benchmark + git submodule update --init --depth 1 -- third_party/cpuinfo/ + git submodule update --init --depth 1 -- third_party/flatcc + git submodule update --init --depth 1 -- third_party/googletest + git submodule update --init --depth 1 -- third_party/hip-build-deps/ - name: Build IREE runtime run: | diff --git a/libshortfin/CMakeLists.txt b/libshortfin/CMakeLists.txt index 0187b1671..a81229d23 100644 --- a/libshortfin/CMakeLists.txt +++ b/libshortfin/CMakeLists.txt @@ -70,6 +70,7 @@ if(SHORTFIN_BUNDLE_DEPS) fmt GIT_REPOSITORY https://github.com/fmtlib/fmt.git GIT_TAG e69e5f977d458f2650bb346dadf2ad30c5320281 # 10.2.1 (sync with spdlog) + GIT_SHALLOW TRUE ) ## spdlog @@ -79,6 +80,7 @@ if(SHORTFIN_BUNDLE_DEPS) spdlog GIT_REPOSITORY https://github.com/gabime/spdlog.git GIT_TAG 2d4acf8cc321d7783d8f2e22e17a794c6d0e9450 # v1.14.1 + GIT_SHALLOW TRUE ) ## xtl: required for xtensor @@ -86,6 +88,7 @@ if(SHORTFIN_BUNDLE_DEPS) xtl GIT_REPOSITORY https://github.com/xtensor-stack/xtl.git GIT_TAG a7c1c5444dfc57f76620391af4c94785ff82c8d6 # v0.7.7 + GIT_SHALLOW TRUE ) ## xtensor @@ -93,6 +96,7 @@ if(SHORTFIN_BUNDLE_DEPS) xtensor GIT_REPOSITORY https://github.com/xtensor-stack/xtensor.git GIT_TAG 3634f2ded19e0cf38208c8b86cea9e1d7c8e397d # v0.25.0 + GIT_SHALLOW TRUE ) FetchContent_MakeAvailable(fmt spdlog xtl xtensor) @@ -111,6 +115,7 @@ if (NOT SHORTFIN_IREE_SOURCE_DIR AND SHORTFIN_BUNDLE_DEPS) # TODO: We shouldn't have to pull googletest when we are not building tests. # This needs to be fixed with IREE. GIT_SUBMODULES "third_party/benchmark third_party/cpuinfo third_party/flatcc third_party/hip-build-deps third_party/googletest" + GIT_SHALLOW TRUE ) FetchContent_GetProperties(iree) if(NOT iree_POPULATED) From a4e1ff22fdd4ad4bda098fea9f77a08e26911126 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 23 Aug 2024 19:08:39 +0200 Subject: [PATCH 08/20] [libshortfin] Set a soversion on `libshortfin` (#138) --- libshortfin/CMakeLists.txt | 2 ++ libshortfin/src/CMakeLists.txt | 2 ++ 2 files changed, 4 insertions(+) diff --git a/libshortfin/CMakeLists.txt b/libshortfin/CMakeLists.txt index a81229d23..64ef168e5 100644 --- a/libshortfin/CMakeLists.txt +++ b/libshortfin/CMakeLists.txt @@ -18,6 +18,8 @@ project( VERSION 0.9 LANGUAGES C CXX) +set(SOVERSION 1) + set(CMAKE_C_STANDARD 11) set(CMAKE_CXX_STANDARD 20) # https://discourse.cmake.org/t/cmake-3-28-cmake-cxx-compiler-clang-scan-deps-notfound-not-found/9244/3 diff --git a/libshortfin/src/CMakeLists.txt b/libshortfin/src/CMakeLists.txt index 0d6d7ace7..de31643e3 100644 --- a/libshortfin/src/CMakeLists.txt +++ b/libshortfin/src/CMakeLists.txt @@ -28,3 +28,5 @@ shortfin_public_library( shortfin_systems_host ${_INIT_INTERNAL_DEPS} ) + +set_target_properties(shortfin PROPERTIES VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR} SOVERSION ${SOVERSION}) From 41190384571a5fc4411f02a0dd093270f2837118 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Sat, 24 Aug 2024 12:46:22 +0200 Subject: [PATCH 09/20] [libshortfin] Allow tests to pass (#148) Lists and installs the missing test deps and removes the expected fail marking. --- .github/workflows/ci_linux_x64-libshortfin.yml | 4 +++- libshortfin/pyproject.toml | 3 +++ libshortfin/requirements-tests.txt | 4 ++++ libshortfin/tests/examples/fastapi_test.py | 3 --- 4 files changed, 10 insertions(+), 4 deletions(-) create mode 100644 libshortfin/requirements-tests.txt diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml index 34b385f6e..233886c97 100644 --- a/.github/workflows/ci_linux_x64-libshortfin.yml +++ b/.github/workflows/ci_linux_x64-libshortfin.yml @@ -82,7 +82,9 @@ jobs: cache: "pip" - name: Install Python packages # TODO: Switch to `pip install -r requirements.txt -e libshortfin/`. - run: pip install nanobind pytest requests + run: | + pip install nanobind + pip install -r ${{ env.LIBSHORTFIN_DIR }}/requirements-tests.txt - name: Build libshortfin (full) run: | diff --git a/libshortfin/pyproject.toml b/libshortfin/pyproject.toml index 5185be707..e868b4264 100644 --- a/libshortfin/pyproject.toml +++ b/libshortfin/pyproject.toml @@ -13,6 +13,9 @@ addopts = [ "-ra", "--import-mode=importlib", ] +markers = [ + "requires_amd_gpu: tests that require and AMD GPU (deselect with '-m \"not requires_amd_gpu\"')", +] testpaths = [ "tests", ] diff --git a/libshortfin/requirements-tests.txt b/libshortfin/requirements-tests.txt new file mode 100644 index 000000000..1049b0412 --- /dev/null +++ b/libshortfin/requirements-tests.txt @@ -0,0 +1,4 @@ +pytest +requests +fastapi +uvicorn diff --git a/libshortfin/tests/examples/fastapi_test.py b/libshortfin/tests/examples/fastapi_test.py index f34cf7e24..f19c1c12f 100644 --- a/libshortfin/tests/examples/fastapi_test.py +++ b/libshortfin/tests/examples/fastapi_test.py @@ -29,13 +29,11 @@ def server(): # Test error first to make sure it doesn't mess up the server. -@pytest.mark.xfail(raises=RuntimeError, reason="Failing (but should work)") def test_error_response(server): resp = requests.get(f"{server.url}/predict?value=0") assert resp.status_code == 500 -@pytest.mark.xfail(raises=RuntimeError, reason="Failing (but should work)") def test_single_response(server): resp = requests.get(f"{server.url}/predict?value=1") resp.raise_for_status() @@ -44,7 +42,6 @@ def test_single_response(server): assert full_contents == b'{"answer":1}' -@pytest.mark.xfail(raises=RuntimeError, reason="Failing (but should work)") def test_stream_response(server): resp = requests.get(f"{server.url}/predict?value=20") resp.raise_for_status() From 040761c8670b50a4524e95d8c6c4f59542a4b141 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Sat, 24 Aug 2024 11:47:00 -0700 Subject: [PATCH 10/20] [quant] Prevent faketensoring shape values (#144) All numpy arrays are automatically converted to fake tensors. To support reshape constants we need to ensure that they are cast to integers first. --- sharktank/sharktank/types/gguf_interop/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sharktank/sharktank/types/gguf_interop/base.py b/sharktank/sharktank/types/gguf_interop/base.py index 315a0cd84..a343e333c 100644 --- a/sharktank/sharktank/types/gguf_interop/base.py +++ b/sharktank/sharktank/types/gguf_interop/base.py @@ -125,9 +125,10 @@ def load_file(gguf_path: Union[str, os.PathLike]) -> Dataset: # Extract tensors. tensors: dict[str, InferenceTensor] = {} for tensor in reader.tensors: + shape = [int(d) for d in tensor.shape] gguf_tensor = _wrap_tensor( name=tensor.name, - logical_shape=list(tensor.shape), + logical_shape=list(shape), type_name=tensor.tensor_type.name, data=tensor.data, # type: ignore ) From 7e7b77ed60ce6de5d1d6a2384ccc8ae21714ea59 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sat, 24 Aug 2024 20:55:46 -0700 Subject: [PATCH 11/20] [shortfin] Build out basic array interop. (#150) * In C++, `device_array` subclasses can now interop with xtensor nd-array views. This is used to make the `to_s` / `__repr__` method pretty print contents for now. * C++ and Python APIs for mapping host memory. * Python `storage.map()` if used for explicit mapping (with arguments to control details) whereas the `storage.data` property maps as read-only on getattr and as write/discard on setattr. Combined with the buffer protocol, this makes mapping work gluelessly with other Python buffer based objects. The xtensor_bridge is relatively complex because it has to bridge from the polymorphic dtype world of base_array to the static element type world of xtensor. It does this by having an array mixin class which lazily constructs an indirection trampoline object that is (private to the xtensor_bridge TU) template instantiated for each supported concrete C++ type. When the trampoline is needed, memory is mapped and the trampoline is constructed, if possible. Because of the type-erasure and desire to minimize per-array overhead for the case when the feature is not used, this involves two layers of in-place construction to arrive at the concrete xtensor adaptor. This is enough to implement basic things like I/O. A bit more will be needed to add capabilities from there. C++ code which knows what types it is operating on can also directly instantiate an adaptor, but even then, it is good to have some generic array support (like I/O, fill, tranpose, etc). This was organized so the feature could be ifdef'd out if not desired. In a pure C++ static library, it should just be dropped by the linker if not used, but it does add compile time and complexity. Current cost is about 175KiB for the xtensor bridge (out of a total library size of 1.5MiB). This will increase somewhat as more features are used. --- .../workflows/ci_linux_x64-libshortfin.yml | 2 +- libshortfin/bindings/python/array_binding.cc | 172 ++++++++++-- libshortfin/bindings/python/shortfin/array.py | 2 - libshortfin/bindings/python/utils.h | 58 +++- libshortfin/src/shortfin/array/CMakeLists.txt | 12 + libshortfin/src/shortfin/array/api.h | 2 + libshortfin/src/shortfin/array/array.cc | 49 +++- libshortfin/src/shortfin/array/array.h | 160 +++++------ libshortfin/src/shortfin/array/array_test.cc | 62 +++++ libshortfin/src/shortfin/array/dims.h | 256 ++++++++++++++++++ libshortfin/src/shortfin/array/dims_test.cc | 148 ++++++++++ libshortfin/src/shortfin/array/dtype_test.cc | 34 +++ libshortfin/src/shortfin/array/storage.cc | 77 +++++- libshortfin/src/shortfin/array/storage.h | 162 ++++++++++- .../src/shortfin/array/xtensor_bridge.cc | 93 +++++++ .../src/shortfin/array/xtensor_bridge.h | 160 +++++++++++ libshortfin/src/shortfin/local/scope.h | 16 +- .../src/shortfin/support/CMakeLists.txt | 2 +- libshortfin/tests/amdgpu_system_test.py | 4 +- libshortfin/tests/array_test.py | 73 ++++- libshortfin/tests/local_scope_test.py | 4 +- libshortfin/tests/smoke_test.py | 11 - 22 files changed, 1380 insertions(+), 179 deletions(-) create mode 100644 libshortfin/src/shortfin/array/array_test.cc create mode 100644 libshortfin/src/shortfin/array/dims.h create mode 100644 libshortfin/src/shortfin/array/dims_test.cc create mode 100644 libshortfin/src/shortfin/array/dtype_test.cc create mode 100644 libshortfin/src/shortfin/array/xtensor_bridge.cc create mode 100644 libshortfin/src/shortfin/array/xtensor_bridge.h delete mode 100644 libshortfin/tests/smoke_test.py diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml index 233886c97..20f944c5b 100644 --- a/.github/workflows/ci_linux_x64-libshortfin.yml +++ b/.github/workflows/ci_linux_x64-libshortfin.yml @@ -105,7 +105,7 @@ jobs: cd ${{ env.LIBSHORTFIN_DIR }}/build cmake --build . --target test cd ${{ env.LIBSHORTFIN_DIR }} - pytest -m "not requires_amd_gpu" + pytest -s -v -m "not requires_amd_gpu" - name: Build libshortfin (host-only) run: | diff --git a/libshortfin/bindings/python/array_binding.cc b/libshortfin/bindings/python/array_binding.cc index b7a3fb752..fc4694107 100644 --- a/libshortfin/bindings/python/array_binding.cc +++ b/libshortfin/bindings/python/array_binding.cc @@ -12,6 +12,53 @@ using namespace shortfin::array; namespace shortfin::python { +namespace { +static const char DOCSTRING_STORAGE_DATA[] = R"(Access raw binary contents. + +Accessing `foo = storage.data` is equivalent to `storage.data.map(read=True)`. +The returned object is a context manager that will close on exit. + +Assigning `storage.data = array.array("f", [1.0])` will copy that raw data +from the source object using the buffer protocol. The source data must be +less than or equal to the length of the storage object. Note that the entire +storage is mapped as write-only/discardable, and writing less than the storage +bytes leaves any unwritten contents in an undefined state. + +As with `map`, this will only work on buffers that are host visible, which +includes all host buffers and device buffers created with the necessary access. +)"; + +static const char DOCSTRING_STORAGE_MAP[] = + R"(Create a mapping of the buffer contents in host memory. + +Support kwargs of: + +read: Enables read access to the mapped memory. +write: Enables write access to the mapped memory and will flush upon close + (for non-unified memory systems). +discard: Indicates that the entire memory map should be treated as if it will + be overwritten. Initial contents will be undefined. + +Mapping memory for access from the host requires a compatible buffer that has +been created with host visibility (which includes host buffers). + +The returned mapping object is a context manager that will close/flush on +exit. Alternatively, the `close()` method can be invoked explicitly. +)"; + +// Does in-place creation of a mapping object and stores a pointer to the +// contained array::mapping C++ object. +py::object CreateMappingObject(mapping **out_cpp_mapping) { + py::object py_mapping = py::inst_alloc(py::type()); + mapping *cpp_mapping = py::inst_ptr(py_mapping); + new (cpp_mapping) mapping(); + py::inst_mark_ready(py_mapping); + *out_cpp_mapping = cpp_mapping; + return py_mapping; +} + +} // namespace + void BindArray(py::module_ &m) { py::class_(m, "DType") .def_prop_ro("is_boolean", &DType::is_boolean) @@ -52,6 +99,7 @@ void BindArray(py::module_ &m) { m.attr("complex64") = DType::complex64(); m.attr("complex128") = DType::complex128(); + // storage py::class_(m, "storage") .def_static( "allocate_host", @@ -75,8 +123,82 @@ void BindArray(py::module_ &m) { PyBufferReleaser py_view_releaser(py_view); self.Fill(py_view.buf, py_view.len); }) + .def( + "map", + [](storage &self, bool read, bool write, bool discard) { + int access = 0; + if (read) access |= IREE_HAL_MEMORY_ACCESS_READ; + if (write) access |= IREE_HAL_MEMORY_ACCESS_WRITE; + if (discard) access |= IREE_HAL_MEMORY_ACCESS_DISCARD; + if (!access) { + throw std::invalid_argument( + "One of the access flags must be set"); + } + mapping *cpp_mapping = nullptr; + py::object py_mapping = CreateMappingObject(&cpp_mapping); + self.MapExplicit( + *cpp_mapping, + static_cast(access)); + return py_mapping; + }, + py::kw_only(), py::arg("read") = false, py::arg("write") = false, + py::arg("discard") = false, DOCSTRING_STORAGE_MAP) + // The 'data' prop is a short-hand for accessing the backing storage + // in a one-shot manner (as for reading or writing). Getting the attribute + // will map for read and return a memory view (equiv to map(read=True)). + // On write, it will accept an object implementing the buffer protocol + // and write/discard the backing storage. + .def_prop_rw( + "data", + [](storage &self) { + mapping *cpp_mapping = nullptr; + py::object py_mapping = CreateMappingObject(&cpp_mapping); + *cpp_mapping = self.MapRead(); + return py_mapping; + }, + [](storage &self, py::handle buffer_obj) { + PyBufferRequest src_info(buffer_obj, PyBUF_SIMPLE); + auto dest_data = self.MapWriteDiscard(); + if (src_info.view().len > dest_data.size()) { + throw std::invalid_argument( + fmt::format("Cannot write {} bytes into buffer of {} bytes", + src_info.view().len, dest_data.size())); + } + std::memcpy(dest_data.data(), src_info.view().buf, + src_info.view().len); + }, + DOCSTRING_STORAGE_DATA) .def("__repr__", &storage::to_s); + // mapping + auto mapping_class = py::class_(m, "mapping"); + mapping_class.def("close", &mapping::reset) + .def_prop_ro("valid", [](mapping &self) -> bool { return self; }) + .def("__enter__", [](py::object self_obj) { return self_obj; }) + .def( + "__exit__", + [](mapping &self, py::handle exc_type, py::handle exc_value, + py::handle exc_tb) { self.reset(); }, + py::arg("exc_type").none(), py::arg("exc_value").none(), + py::arg("exc_tb").none()); + struct MappingBufferHandler { + int operator()(mapping &self, Py_buffer *view, int flags) { + view->buf = self.data(); + view->len = self.size(); + view->readonly = self.writable(); + view->itemsize = 1; + view->format = (char *)"B"; // Byte + view->ndim = 1; + view->shape = nullptr; + view->strides = nullptr; + view->suboffsets = nullptr; + view->internal = nullptr; + return 0; + } + }; + BindBufferProtocol(mapping_class); + + // base_array and subclasses py::class_(m, "base_array") .def_prop_ro("dtype", &base_array::dtype) .def_prop_ro("shape", &base_array::shape); @@ -94,40 +216,34 @@ void BindArray(py::module_ &m) { std::span shape, DType dtype) { return custom_new_keep_alive( py_type, /*keep_alive=*/device.scope(), - device_array::allocate(device, shape, dtype)); + device_array::for_device(device, shape, dtype)); }) - .def_prop_ro("device", &device_array::device, - py::rv_policy::reference_internal) - .def_prop_ro("storage", &device_array::storage, - py::rv_policy::reference_internal) - .def("__repr__", &device_array::to_s); - py::class_(m, "host_array") - .def("__init__", [](py::args, py::kwargs) {}) - .def_static("__new__", - [](py::handle py_type, class storage storage, - std::span shape, DType dtype) { - return custom_new_keep_alive( - py_type, /*keep_alive=*/storage.scope(), storage, shape, - dtype); + .def_static("for_device", + [](local::ScopedDevice &device, std::span shape, + DType dtype) { + return custom_new_keep_alive( + py::type(), /*keep_alive=*/device.scope(), + device_array::for_device(device, shape, dtype)); }) - .def_static("__new__", - [](py::handle py_type, local::ScopedDevice &device, - std::span shape, DType dtype) { - return custom_new_keep_alive( - py_type, /*keep_alive=*/device.scope(), - host_array::allocate(device, shape, dtype)); + .def_static("for_host", + [](local::ScopedDevice &device, std::span shape, + DType dtype) { + return custom_new_keep_alive( + py::type(), /*keep_alive=*/device.scope(), + device_array::for_host(device, shape, dtype)); }) - .def_static("__new__", - [](py::handle py_type, device_array &device_array) { - return custom_new_keep_alive( - py_type, /*keep_alive=*/device_array.device().scope(), - host_array::for_transfer(device_array)); + .def_static("for_transfer", + [](device_array &existing) { + return custom_new_keep_alive( + py::type(), + /*keep_alive=*/existing.device().scope(), + device_array::for_transfer(existing)); }) - .def_prop_ro("device", &host_array::device, + .def_prop_ro("device", &device_array::device, py::rv_policy::reference_internal) - .def_prop_ro("storage", &host_array::storage, + .def_prop_ro("storage", &device_array::storage, py::rv_policy::reference_internal) - .def("__repr__", &host_array::to_s); + .def("__repr__", &device_array::to_s); } } // namespace shortfin::python diff --git a/libshortfin/bindings/python/shortfin/array.py b/libshortfin/bindings/python/shortfin/array.py index e99595554..049fe9ed7 100644 --- a/libshortfin/bindings/python/shortfin/array.py +++ b/libshortfin/bindings/python/shortfin/array.py @@ -37,7 +37,6 @@ base_array = _sfl.array.base_array device_array = _sfl.array.device_array -host_array = _sfl.array.host_array storage = _sfl.array.storage DType = _sfl.array.DType @@ -73,7 +72,6 @@ # Classes. "base_array", "device_array", - "host_array", "storage", "DType", ] diff --git a/libshortfin/bindings/python/utils.h b/libshortfin/bindings/python/utils.h index 24e2a6642..dab4423f8 100644 --- a/libshortfin/bindings/python/utils.h +++ b/libshortfin/bindings/python/utils.h @@ -14,10 +14,10 @@ namespace shortfin::python { // Casts any of int, str, local::Device, DeviceAffinity to a DeviceAffinity. // If the object is a sequence, then the affinity is constructed from the union. -inline local::ScopedDevice CastDeviceAffinity(local::Scope &scope, +inline local::ScopedDevice CastDeviceAffinity(local::Scope& scope, py::handle object) { if (py::isinstance(object)) { - return scope.device(py::cast(object)); + return scope.device(py::cast(object)); } else if (py::isinstance(object)) { return local::ScopedDevice(scope, py::cast(object)); } else if (py::isinstance(object)) { @@ -39,4 +39,58 @@ inline local::ScopedDevice CastDeviceAffinity(local::Scope &scope, py::repr(object).c_str())); } +// For a bound class, binds the buffer protocol. This will result in a call +// to handler like: +// HandlerFunctor(self, Py_buffer *view, int flags) +// This is a low level callback and must not raise any exceptions. If +// error conditions are warranted the usual PyErr_SetString approach must be +// used (and -1 returned). Return 0 on success. +template +void BindBufferProtocol(py::handle clazz) { + PyBufferProcs buffer_procs; + memset(&buffer_procs, 0, sizeof(buffer_procs)); + buffer_procs.bf_getbuffer = + // It is not legal to raise exceptions from these callbacks. + +[](PyObject* raw_self, Py_buffer* view, int flags) noexcept -> int { + if (view == NULL) { + PyErr_SetString(PyExc_ValueError, "NULL view in getbuffer"); + return -1; + } + + // Cast must succeed due to invariants. + auto& self = py::cast(py::handle(raw_self)); + + Py_INCREF(raw_self); + view->obj = raw_self; + HandlerFunctor handler; + return handler(self, view, flags); + }; + buffer_procs.bf_releasebuffer = + +[](PyObject* raw_self, Py_buffer* view) noexcept -> void {}; + auto heap_type = reinterpret_cast(clazz.ptr()); + assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE && + "must be heap type"); + heap_type->as_buffer = buffer_procs; +} + +// Represents a Py_buffer obtained via PyObject_GetBuffer() and terminated via +// PyBuffer_Release(). +class PyBufferRequest { + public: + PyBufferRequest(py::handle& exporter, int flags) { + int rc = PyObject_GetBuffer(exporter.ptr(), &view_, flags); + if (rc != 0) { + throw py::python_error(); + } + } + ~PyBufferRequest() { PyBuffer_Release(&view_); } + PyBufferRequest(const PyBufferRequest&) = delete; + void operator=(const PyBufferRequest&) = delete; + + Py_buffer& view() { return view_; } + + private: + Py_buffer view_; +}; + } // namespace shortfin::python diff --git a/libshortfin/src/shortfin/array/CMakeLists.txt b/libshortfin/src/shortfin/array/CMakeLists.txt index da22e3cc0..0e9360363 100644 --- a/libshortfin/src/shortfin/array/CMakeLists.txt +++ b/libshortfin/src/shortfin/array/CMakeLists.txt @@ -10,13 +10,25 @@ shortfin_cc_component( HDRS array.h api.h + dims.h dtype.h storage.h SRCS array.cc dtype.cc storage.cc + xtensor_bridge.cc COMPONENTS shortfin_local shortfin_support + DEPS + xtensor +) + +shortfin_gtest_test( + NAME shortfin_array_test + SRCS + array_test.cc + dims_test.cc + dtype_test.cc ) diff --git a/libshortfin/src/shortfin/array/api.h b/libshortfin/src/shortfin/array/api.h index e7f73ede4..baa8a55ea 100644 --- a/libshortfin/src/shortfin/array/api.h +++ b/libshortfin/src/shortfin/array/api.h @@ -8,7 +8,9 @@ #define SHORTFIN_ARRAY_API_H #include "shortfin/array/array.h" +#include "shortfin/array/dims.h" #include "shortfin/array/dtype.h" #include "shortfin/array/storage.h" +#include "shortfin/array/xtensor_bridge.h" #endif // SHORTFIN_ARRAY_API_H diff --git a/libshortfin/src/shortfin/array/array.cc b/libshortfin/src/shortfin/array/array.cc index 74d20e47e..1d6d7cc5a 100644 --- a/libshortfin/src/shortfin/array/array.cc +++ b/libshortfin/src/shortfin/array/array.cc @@ -6,29 +6,56 @@ #include "shortfin/array/array.h" +#include + #include "fmt/core.h" #include "fmt/ranges.h" +#include "shortfin/array/xtensor_bridge.h" namespace shortfin::array { +template class InlinedDims; + // -------------------------------------------------------------------------- // // device_array // -------------------------------------------------------------------------- // -std::string device_array::to_s() const { - return fmt::format("device_array([{}], dtype='{}', {})", - fmt::join(shape(), ", "), dtype().name(), - storage_.device().to_s()); -} +const mapping device_array::data() const { return storage_.MapRead(); } -// -------------------------------------------------------------------------- // -// host_array -// -------------------------------------------------------------------------- // +mapping device_array::data() { return storage_.MapRead(); } + +mapping device_array::data_rw() { return storage_.MapReadWrite(); } + +mapping device_array::data_w() { return storage_.MapWriteDiscard(); } -std::string host_array::to_s() const { - return fmt::format("host_array([{}], dtype='{}', {})", +std::optional device_array::map_memory_for_xtensor() { + if (storage_.is_mappable_for_read_write()) { + return storage_.MapReadWrite(); + } else if (storage_.is_mappable_for_read()) { + return storage_.MapRead(); + } + return {}; +} + +std::string device_array::to_s() const { + std::string contents; + const char *contents_prefix = " "; + if (!storage_.is_mappable_for_read()) { + contents = ""; + } else { + auto maybe_contents = contents_to_s(); + if (maybe_contents) { + contents = std::move(*maybe_contents); + contents_prefix = "\n"; + } else { + contents = ""; + } + } + + return fmt::format("device_array([{}], dtype='{}', device={}({})) ={}{}", fmt::join(shape(), ", "), dtype().name(), - storage_.device().to_s()); + storage_.device().to_s(), storage_.formatted_memory_type(), + contents_prefix, contents); } } // namespace shortfin::array diff --git a/libshortfin/src/shortfin/array/array.h b/libshortfin/src/shortfin/array/array.h index 1e0ea80d2..31deb665a 100644 --- a/libshortfin/src/shortfin/array/array.h +++ b/libshortfin/src/shortfin/array/array.h @@ -12,8 +12,10 @@ #include #include +#include "shortfin/array/dims.h" #include "shortfin/array/dtype.h" #include "shortfin/array/storage.h" +#include "shortfin/array/xtensor_bridge.h" #include "shortfin/support/api.h" namespace shortfin::array { @@ -28,129 +30,99 @@ class SHORTFIN_API base_array { // a value type because the Dims union is otherwise not copy/movable. base_array(const base_array &other) : base_array(other.shape(), other.dtype()) {} - base_array(base_array &&other) : rank_(other.rank_), dtype_(other.dtype_) { - // Custom move the dims to avoid an additional allocation. This could just - // be a memcpy on most impls, but this is the "right way". - if (rank_ > MAX_INLINE_RANK) { - // Dynamic allocation. - new (&shape_.dynamic_dims) Dims(); - shape_.dynamic_dims = std::move(other.shape_.dynamic_dims); - } else { - // Inline allocation. - new (&shape_.inline_dims) Dims(); - shape_.inline_dims = other.shape_.inline_dims; - } - other.rank_ = 0; - } - virtual ~base_array() { ClearDims(); } + base_array(base_array &&other) + : dtype_(other.dtype_), shape_(std::move(other.shape_)) {} + virtual ~base_array() = default; + virtual std::string to_s() const = 0; DType dtype() const { return dtype_; } // Access shape. - void set_shape(std::span shape) { - ClearDims(); - rank_ = shape.size(); - if (rank_ > MAX_INLINE_RANK) { - // Dynamic allocation. - new (&shape_.dynamic_dims) std::unique_ptr(new size_t[rank_]); - std::copy(shape.begin(), shape.end(), shape_.dynamic_dims.get()); - } else { - // Inline allocation. - new (&shape_.inline_dims) Dims(); - std::copy(shape.begin(), shape.end(), shape_.inline_dims.begin()); - } - } - std::span shape() const { - if (rank_ > MAX_INLINE_RANK) { - // Dynamic allocation. - return std::span(shape_.dynamic_dims.get(), rank_); - } else { - // Inline allocation. - return std::span(&shape_.inline_dims.front(), rank_); - } - } - std::span mutable_shape() { - if (rank_ > MAX_INLINE_RANK) { - // Dynamic allocation. - return std::span(shape_.dynamic_dims.get(), rank_); - } else { - // Inline allocation. - return std::span(&shape_.inline_dims.front(), rank_); - } - } + void set_shape(std::span shape) { shape_.set(shape); } + std::span shape() const { return shape_.span(); } + std::span mutable_shape() { return shape_.span(); } - private: - static constexpr size_t MAX_INLINE_RANK = 6; - union Dims { - Dims() {} - ~Dims() {} - std::array inline_dims; - std::unique_ptr dynamic_dims; - }; - - // Clears shape, setting the rank to zero and deleting any non-inline - // dimension storage. - void ClearDims() { - if (rank_ > MAX_INLINE_RANK) { - shape_.dynamic_dims.~unique_ptr(); - } - rank_ = 0; - } + // Sometimes we need to access the raw shape container (i.e. for adapters, + // etc). + Dims &shape_container() { return shape_; } + const Dims &shape_container() const { return shape_; } - size_t rank_ = 0; + private: DType dtype_; Dims shape_; }; -// View over some device allocation, modeled as a dense C-order nd array. -class SHORTFIN_API device_array final : public base_array { +class SHORTFIN_API device_array + : public base_array, + public poly_xt_mixin { public: device_array(class storage storage, std::span shape, DType dtype) : base_array(shape, dtype), storage_(std::move(storage)) {} - static device_array allocate(local::ScopedDevice &device, - std::span shape, DType dtype) { + class storage &storage() { return storage_; } + local::ScopedDevice &device() { return storage_.device(); } + + // Allocate an array on the device. + static device_array for_device(local::ScopedDevice &device, + std::span shape, DType dtype) { return device_array( storage::AllocateDevice(device, dtype.compute_dense_nd_size(shape)), shape, dtype); } - class storage &storage() { return storage_; } - local::ScopedDevice &device() { return storage_.device(); } - std::string to_s() const; - - private: - class storage storage_; -}; - -// View over some host allocation, registered for transfer to/from the -// device. -// These arrays can either be allocated directly or ::for_transfer with -// a corresponding device_array. -class SHORTFIN_API host_array final : public base_array { - public: - host_array(class storage storage, std::span shape, DType dtype) - : base_array(shape, dtype), storage_(std::move(storage)) {} - - static host_array allocate(local::ScopedDevice &device, - std::span shape, DType dtype) { - return host_array( + // Allocates a host array that is registered by the device. This can include + // arrays that are visible from different combinations of host/device. + static device_array for_host(local::ScopedDevice &device, + std::span shape, DType dtype) { + return device_array( storage::AllocateHost(device, dtype.compute_dense_nd_size(shape)), shape, dtype); } // Allocates a host array for transfer to/from the given device array. - static host_array for_transfer(device_array &with_device_array) { - return allocate(with_device_array.storage().device(), + static device_array for_transfer(device_array &with_device_array) { + return for_host(with_device_array.storage().device(), with_device_array.shape(), with_device_array.dtype()); } - class storage &storage() { return storage_; } - local::ScopedDevice &device() { return storage_.device(); } - std::string to_s() const; + // Untyped access to the backing data. The array must be mappable. Specific + // access modes: + // * data(): Read-only access to the data. + // * data_rw(): Read/write access to the data. + // * data_w(): Write-only access to the data with discard (initial contents + // are undefined.) + const mapping data() const; + mapping data(); + // Map the array's data for read-write untyped access. + mapping data_rw(); + // Map the array's data for write-only untyped access. + mapping data_w(); + + // Maps memory for bridging to xtensor. If mapping is unsupported, return {}. + std::optional map_memory_for_xtensor(); + + // Typed access to the backing data. + template + typed_mapping typed_data() { + return typed_mapping(data()); + } + template + typed_mapping typed_data() const { + return typed_mapping(data()); + } + template + typed_mapping typed_data_rw() { + return typed_mapping(data_rw()); + } + template + typed_mapping typed_data_w() { + return typed_mapping(data_w()); + } - private: + std::string to_s() const override; + + protected: class storage storage_; }; diff --git a/libshortfin/src/shortfin/array/array_test.cc b/libshortfin/src/shortfin/array/array_test.cc new file mode 100644 index 000000000..2c435b292 --- /dev/null +++ b/libshortfin/src/shortfin/array/array_test.cc @@ -0,0 +1,62 @@ +// Copyright 2024 Advanced Micro Devices, Inc +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include +#include + +#include "shortfin/array/api.h" +#include "shortfin/local/systems/host.h" + +using namespace shortfin; +using namespace shortfin::local; +using namespace shortfin::array; + +namespace { + +class DeviceArrayTest : public testing::Test { + protected: + DeviceArrayTest() {} + + void SetUp() override { + system = systems::HostCPUSystemBuilder().CreateSystem(); + scope = system->CreateScope(system->init_worker(), system->devices()); + device = scope->device(0); + } + void TearDown() override { + system->Shutdown(); + system.reset(); + } + + SystemPtr system; + std::shared_ptr scope; + ScopedDevice device; +}; + +TEST_F(DeviceArrayTest, contents_to_s_valid) { + device_array ary1 = device_array::for_host( + device, std::to_array({2, 3}), DType::float32()); + { + auto map = ary1.typed_data_w(); + std::fill(map.begin(), map.end(), 42.0); + } + + std::optional contents = ary1.contents_to_s(); + ASSERT_TRUE(contents); + EXPECT_EQ(*contents, "{{ 42., 42., 42.},\n { 42., 42., 42.}}"); +} + +TEST_F(DeviceArrayTest, contents_to_s_invalid) { + device_array ary1 = device_array::for_host( + device, std::to_array({2, 3}), DType::opaque32()); + // No xtensor adaptor for opaque32. + std::optional contents = ary1.contents_to_s(); + ASSERT_FALSE(contents); +} + +} // namespace diff --git a/libshortfin/src/shortfin/array/dims.h b/libshortfin/src/shortfin/array/dims.h new file mode 100644 index 000000000..529aebc42 --- /dev/null +++ b/libshortfin/src/shortfin/array/dims.h @@ -0,0 +1,256 @@ +// Copyright 2024 Advanced Micro Devices, Inc +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef SHORTFIN_ARRAY_DIMS_H +#define SHORTFIN_ARRAY_DIMS_H + +#include +#include +#include + +#include "shortfin/support/api.h" + +namespace shortfin::array { + +// Vector-alike for storing inlined dims. Note that this has a template +// signature identical to std::vector because xtensor specializes on this +// exact signature. See the concrete size_t instantiation below. +template > +class SHORTFIN_API InlinedDims { + public: + using element_type = T; + using value_type = T; + using allocator_type = Alloc; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using reference = value_type &; + using const_reference = const value_type &; + using pointer = value_type *; + using const_pointer = const value_type *; + + class iterator { + public: + using difference_type = std::ptrdiff_t; + using value_type = T; + using pointer = T *; + using reference = T &; + using iterator_category = std::random_access_iterator_tag; + iterator(pointer p) : p(p) {} + iterator &operator++() { + p++; + return *this; + } + iterator &operator++(int) { + p++; + return *this; + } + bool operator==(iterator other) const { return p == other.p; } + bool operator!=(iterator other) const { return p != other.p; } + reference operator*() { return *p; } + + private: + pointer p; + }; + class const_iterator { + public: + using difference_type = std::ptrdiff_t; + using value_type = const T; + using pointer = const T *; + using reference = const T &; + using iterator_category = std::random_access_iterator_tag; + + const_iterator(pointer p) : p(p) {} + const_iterator &operator++() { + p++; + return *this; + } + const_iterator &operator++(int) { + p++; + return *this; + } + bool operator==(const_iterator other) const { return p == other.p; } + bool operator!=(const_iterator other) const { return p != other.p; } + reference operator*() { return *p; } + + private: + pointer p; + }; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; + + InlinedDims() { new (&dims_.inline_dims) InlineTy(); } + InlinedDims(size_type count, T value = T()) : size_(count) { + if (size_ > MAX_INLINE_RANK) { + // Dynamic allocation. + new (&dims_.dynamic_dims) DynamicTy(new element_type[size_]); + std::fill(dims_.dynamic_dims.get(), dims_.dynamic_dims.get() + size_, + value); + } else { + // Inline allocation. + new (&dims_.inline_dims) InlineTy(); + std::fill(dims_.inline_dims.begin(), dims_.inline_dims.end(), value); + } + } + InlinedDims(const InlinedDims &other) { + new (&dims_.inline_dims) InlineTy(); + set(other.span()); + } + InlinedDims(InlinedDims &&other) : size_(other.size_) { + // Custom move the dims to avoid an additional allocation. This could just + // be a memcpy on most impls, but this is the "right way". + if (size_ > MAX_INLINE_RANK) { + // Dynamic allocation. + new (&dims_.dynamic_dims) DynamicTy(); + dims_.dynamic_dims = std::move(other.dims_.dynamic_dims); + } else { + // Inline allocation. + new (&dims_.inline_dims) InlineTy(); + dims_.inline_dims = other.dims_.inline_dims; + } + other.size_ = 0; + } + InlinedDims &operator=(const InlinedDims &other) { + set(other.span()); + return *this; + } + ~InlinedDims() { clear(); } + + T *data() { + if (size_ > MAX_INLINE_RANK) { + return dims_.dynamic_dims.get(); + } else { + return &dims_.inline_dims.front(); + } + } + const T *data() const { + if (size_ > MAX_INLINE_RANK) { + return dims_.dynamic_dims.get(); + } else { + return &dims_.inline_dims.front(); + } + } + std::size_t size() const { return size_; } + bool empty() const { return size_ == 0; } + + // Clears shape, setting the rank to zero and deleting any non-inline + // dimension storage. + void clear() { + if (size_ > MAX_INLINE_RANK) { + dims_.dynamic_dims.~unique_ptr(); + } else { + dims_.inline_dims.~array(); + } + size_ = 0; + } + + void set(std::span dims) { + clear(); + size_ = dims.size(); + if (size_ > MAX_INLINE_RANK) { + // Dynamic allocation. + new (&dims_.dynamic_dims) DynamicTy(new element_type[size_]); + std::copy(dims.begin(), dims.end(), dims_.dynamic_dims.get()); + } else { + // Inline allocation. + new (&dims_.inline_dims) InlineTy(); + std::copy(dims.begin(), dims.end(), dims_.inline_dims.begin()); + } + } + + // Container access. + iterator begin() { return iterator(data()); } + iterator end() { return iterator(data() + size()); } + const_iterator begin() const { return const_iterator(data()); } + const_iterator end() const { return const_iterator(data() + size()); } + const_iterator cbegin() const { return const_iterator(data()); } + const_iterator cend() const { return const_iterator(data() + size()); } + + void resize(size_type count) { resize_impl(count, value_type()); } + void resize(size_type count, value_type value) { resize_impl(count, value); } + + reference operator[](std::size_t idx) { return *(data() + idx); } + const_reference operator[](std::size_t idx) const { return *(data() + idx); } + + reference front() { return *data(); } + const_reference front() const { return *data(); } + reference back() { return *(data() + size() - 1); } + const_reference back() const { return *(data() + size() - 1); } + + // Access as a span. + std::span span() { return std::span(data(), size_); } + std::span span() const { return std::span(data(), size_); } + + private: + void resize_impl(size_type count, value_type value) { + if (count == size()) return; + if (size() > MAX_INLINE_RANK) { + // Currently dynamically allocated. + if (count < size()) { + // Truncate. + if (count < MAX_INLINE_RANK) { + // Switch to inlined. + InlineTy new_array; + for (std::size_t i = 0; i < count; ++i) + new_array[i] = dims_.dynamic_dims[i]; + dims_.dynamic_dims.~unique_ptr(); + new (&dims_.inline_dims) InlineTy(new_array); + size_ = count; + } else { + // Stay dynamic and just truncate. + size_ = count; + } + } else { + // Expand and stay dynamic. + DynamicTy new_array(new element_type[count]); + for (std::size_t i = 0; i < size_; ++i) + new_array[i] = dims_.dynamic_dims[i]; + for (std::size_t i = size_; i < count; ++i) new_array[i] = value; + dims_.dynamic_dims = std::move(new_array); + size_ = count; + } + } else { + // Currently inlined. + if (count < size()) { + // Truncate. + size_ = count; + } else if (count < MAX_INLINE_RANK) { + // Stay inlined and initialize new items. + for (std::size_t i = size_; i < count; ++i) + dims_.inline_dims[i] = value; + size_ = count; + } else { + // Need to switch to dynamic size. + DynamicTy new_array(new element_type[count]); + for (std::size_t i = 0; i < size_; ++i) + new_array[i] = dims_.inline_dims[i]; + for (std::size_t i = size_; i < count; ++i) new_array[i] = value; + dims_.inline_dims.~array(); + new (&dims_.dynamic_dims) DynamicTy(std::move(new_array)); + size_ = count; + } + } + } + + static constexpr size_t MAX_INLINE_RANK = 6; + using InlineTy = std::array; + using DynamicTy = std::unique_ptr; + union _D { + _D() {} + ~_D() {} + InlineTy inline_dims; + DynamicTy dynamic_dims; + }; + + std::size_t size_ = 0; + _D dims_; +}; + +extern template class InlinedDims; +using Dims = InlinedDims; + +} // namespace shortfin::array + +#endif // SHORTFIN_ARRAY_DIMS_H diff --git a/libshortfin/src/shortfin/array/dims_test.cc b/libshortfin/src/shortfin/array/dims_test.cc new file mode 100644 index 000000000..287e2fa9a --- /dev/null +++ b/libshortfin/src/shortfin/array/dims_test.cc @@ -0,0 +1,148 @@ +// Copyright 2024 Advanced Micro Devices, Inc +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "shortfin/array/dims.h" + +#include +#include + +#include + +namespace shortfin::array { + +TEST(array_dims, empty) { + Dims dims; + EXPECT_TRUE(dims.empty()); + EXPECT_EQ(dims.size(), 0); +} + +TEST(array_dims, inline_init) { + Dims dims(3, 42); + EXPECT_EQ(dims.size(), 3); + for (size_t i = 0; i < 3; ++i) { + EXPECT_EQ(dims[i], 42); + } + + Dims copy(dims); + EXPECT_EQ(dims.size(), copy.size()); + EXPECT_TRUE(std::equal(dims.begin(), dims.end(), copy.begin())); + EXPECT_TRUE(std::equal(dims.cbegin(), dims.cend(), copy.begin())); + + Dims move = std::move(copy); + EXPECT_EQ(dims.size(), move.size()); + EXPECT_TRUE(std::equal(dims.begin(), dims.end(), move.begin())); + + Dims assign; + assign = dims; + EXPECT_EQ(dims.size(), assign.size()); + EXPECT_TRUE(std::equal(dims.begin(), dims.end(), assign.begin())); + + EXPECT_EQ(*dims.data(), *assign.data()); + + assign.clear(); + EXPECT_TRUE(assign.empty()); +} + +TEST(array_dims, dynamic_init) { + Dims dims(12, 42); + EXPECT_EQ(dims.size(), 12); + for (size_t i = 0; i < 12; ++i) { + EXPECT_EQ(dims[i], 42); + } + + Dims copy(dims); + EXPECT_EQ(dims.size(), copy.size()); + EXPECT_TRUE(std::equal(dims.begin(), dims.end(), copy.begin())); + EXPECT_TRUE(std::equal(dims.cbegin(), dims.cend(), copy.begin())); + + Dims move = std::move(copy); + EXPECT_EQ(dims.size(), move.size()); + EXPECT_TRUE(std::equal(dims.begin(), dims.end(), move.begin())); + + Dims assign; + assign = dims; + EXPECT_EQ(dims.size(), assign.size()); + EXPECT_TRUE(std::equal(dims.begin(), dims.end(), assign.begin())); + + EXPECT_EQ(*dims.data(), *assign.data()); + + assign.clear(); + EXPECT_TRUE(assign.empty()); +} + +TEST(array_dims, resize_same_size) { + Dims dims(3, 64); + dims.resize(3, 32); + EXPECT_EQ(dims.size(), 3); + for (size_t i = 0; i < 3; ++i) { + EXPECT_EQ(dims[i], 64); + } +} + +TEST(array_dims, resize_inline_to_inline) { + Dims dims(3, 64); + dims.resize(5, 32); + EXPECT_EQ(dims.size(), 5); + for (size_t i = 0; i < 3; ++i) { + EXPECT_EQ(dims[i], 64); + } + for (size_t i = 3; i < 5; ++i) { + EXPECT_EQ(dims[i], 32); + } +} + +TEST(array_dims, resize_inline_to_dynamic) { + Dims dims(3, 64); + dims.resize(12, 32); + EXPECT_EQ(dims.size(), 12); + for (size_t i = 0; i < 3; ++i) { + EXPECT_EQ(dims[i], 64); + } + for (size_t i = 3; i < 12; ++i) { + EXPECT_EQ(dims[i], 32); + } +} + +TEST(array_dims, resize_inline_truncate) { + Dims dims(5, 64); + dims.resize(2, 32); + EXPECT_EQ(dims.size(), 2); + for (size_t i = 0; i < 2; ++i) { + EXPECT_EQ(dims[i], 64); + } +} + +TEST(array_dims, resize_dynamic_to_dynamic) { + Dims dims(12, 64); + dims.resize(15, 32); + EXPECT_EQ(dims.size(), 15); + for (size_t i = 0; i < 12; ++i) { + EXPECT_EQ(dims[i], 64); + } + for (size_t i = 12; i < 15; ++i) { + EXPECT_EQ(dims[i], 32); + } +} + +TEST(array_dims, resize_truncate_to_inline) { + Dims dims(12, 64); + dims.resize(3, 32); + EXPECT_EQ(dims.size(), 3); + for (size_t i = 0; i < 3; ++i) { + EXPECT_EQ(dims[i], 64); + } +} + +TEST(array_dims, resize_truncate_to_dynamic) { + Dims dims(12, 64); + dims.resize(10, 32); + EXPECT_EQ(dims.size(), 10); + for (size_t i = 0; i < 10; ++i) { + EXPECT_EQ(dims[i], 64); + } +} + +} // namespace shortfin::array diff --git a/libshortfin/src/shortfin/array/dtype_test.cc b/libshortfin/src/shortfin/array/dtype_test.cc new file mode 100644 index 000000000..f1dc0477a --- /dev/null +++ b/libshortfin/src/shortfin/array/dtype_test.cc @@ -0,0 +1,34 @@ +// Copyright 2024 Advanced Micro Devices, Inc +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "shortfin/array/dtype.h" + +#include +#include + +#include + +namespace shortfin::array { + +TEST(array_dtype, basics) { + EXPECT_EQ(DType::complex64().name(), "complex64"); + EXPECT_EQ(static_cast(DType::complex64()), + IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64); + EXPECT_TRUE(DType::complex64() == DType::complex64()); + EXPECT_TRUE(DType::complex64() != DType::complex128()); +} + +TEST(array_dtype, compure_dense_nd_size) { + // 0d special case. + EXPECT_EQ(DType::float32().compute_dense_nd_size({}), 4); + // 0 extent special case. + EXPECT_EQ(DType::float32().compute_dense_nd_size(std::array{0, 4}), + 0); + EXPECT_EQ(DType::float32().compute_dense_nd_size(std::array{2, 4}), + 32); +} + +} // namespace shortfin::array diff --git a/libshortfin/src/shortfin/array/storage.cc b/libshortfin/src/shortfin/array/storage.cc index 6eb37267d..6554f3e74 100644 --- a/libshortfin/src/shortfin/array/storage.cc +++ b/libshortfin/src/shortfin/array/storage.cc @@ -14,6 +14,10 @@ namespace shortfin::array { using namespace local; using namespace local::detail; +// -------------------------------------------------------------------------- // +// storage +// -------------------------------------------------------------------------- // + namespace detail { void ThrowIllegalDeviceAffinity(Device *first, Device *second) { throw std::invalid_argument(fmt::format( @@ -99,7 +103,61 @@ void storage::Fill(const void *pattern, iree_host_size_t pattern_length) { } void storage::CopyFrom(storage &source_storage) { - // TODO + throw std::logic_error("CopyFrom NYI"); +} + +bool storage::is_mappable_for_read() const { + return (iree_hal_buffer_allowed_usage(buffer_) & + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) && + (iree_hal_buffer_allowed_access(buffer_) & + IREE_HAL_MEMORY_ACCESS_READ); +} + +bool storage::is_mappable_for_read_write() const { + return (iree_hal_buffer_allowed_usage(buffer_) & + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) && + (iree_hal_buffer_allowed_access(buffer_) & + (IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE)); +} + +void storage::MapExplicit(mapping &mapping, iree_hal_memory_access_t access) { + assert(access != IREE_HAL_MEMORY_ACCESS_NONE); + mapping.reset(); + SHORTFIN_THROW_IF_ERROR(iree_hal_buffer_map_range( + buffer_, IREE_HAL_MAPPING_MODE_SCOPED, access, + /*byte_offset=*/0, byte_length(), &mapping.mapping_)); + mapping.access_ = access; + mapping.hal_device_ownership_baton_ = + iree::hal_device_ptr::borrow_reference(hal_device_ownership_baton_); +} + +iree_hal_memory_type_t storage::memory_type() const { + return iree_hal_buffer_memory_type(buffer_); +} +iree_hal_memory_access_t storage::memory_access() const { + return iree_hal_buffer_allowed_access(buffer_); +} +iree_hal_buffer_usage_t storage::buffer_usage() const { + return iree_hal_buffer_allowed_usage(buffer_); +} + +// Formatted type and access. +std::string storage::formatted_memory_type() const { + iree_bitfield_string_temp_t temp; + auto sv = iree_hal_memory_type_format(memory_type(), &temp); + return std::string(sv.data, sv.size); +} + +std::string storage::formatted_memory_access() const { + iree_bitfield_string_temp_t temp; + auto sv = iree_hal_memory_access_format(memory_access(), &temp); + return std::string(sv.data, sv.size); +} + +std::string storage::formatted_buffer_usage() const { + iree_bitfield_string_temp_t temp; + auto sv = iree_hal_buffer_usage_format(buffer_usage(), &temp); + return std::string(sv.data, sv.size); } std::string storage::to_s() const { @@ -107,4 +165,21 @@ std::string storage::to_s() const { byte_length()); } +// -------------------------------------------------------------------------- // +// mapping +// -------------------------------------------------------------------------- // + +mapping::mapping() { std::memset(&mapping_, 0, sizeof(mapping_)); } + +mapping::~mapping() noexcept { reset(); } + +void mapping::reset() noexcept { + if (*this) { + // Crash the process on failure to unmap. We don't have a good mitigation, + IREE_CHECK_OK(iree_hal_buffer_unmap_range(&mapping_)); + access_ = IREE_HAL_MEMORY_ACCESS_NONE; + hal_device_ownership_baton_.reset(); + } +} + } // namespace shortfin::array diff --git a/libshortfin/src/shortfin/array/storage.h b/libshortfin/src/shortfin/array/storage.h index 644f865ac..36f117cb4 100644 --- a/libshortfin/src/shortfin/array/storage.h +++ b/libshortfin/src/shortfin/array/storage.h @@ -14,6 +14,61 @@ namespace shortfin::array { +// Access to mapped memory. +// Mappings are moveable but not copyable. When default constructed or moved +// from, they will not be valid and have nullptr semantics. +class SHORTFIN_API mapping { + public: + mapping(); + mapping(const mapping &) = delete; + mapping &operator=(const mapping &) = delete; + mapping &operator=(mapping &&other) { + access_ = other.access_; + mapping_ = other.mapping_; + hal_device_ownership_baton_ = std::move(other.hal_device_ownership_baton_); + other.access_ = IREE_HAL_MEMORY_ACCESS_NONE; + std::memset(&other.mapping_, 0, sizeof(other.mapping_)); + return *this; + } + mapping(mapping &&other) + : access_(other.access_), + mapping_(other.mapping_), + hal_device_ownership_baton_( + std::move(other.hal_device_ownership_baton_)) { + other.access_ = IREE_HAL_MEMORY_ACCESS_NONE; + std::memset(&other.mapping_, 0, sizeof(other.mapping_)); + } + ~mapping() noexcept; + + // Whether the mapping is valid. + operator bool() const { return access_ != IREE_HAL_MEMORY_ACCESS_NONE; } + + // Resets the mapping, making it invalid (if not already so); + void reset() noexcept; + + // Access the mapped data. The mapping must be valid or else it is UB. + const uint8_t *data() const { + assert(*this && "mapping is not valid"); + return mapping_.contents.data; + } + uint8_t *data() { + assert(*this && "mapping is not valid"); + return mapping_.contents.data; + } + + // The size of the mapped data. Will return 0 if the mapping is not valid. + iree_device_size_t size() const { return mapping_.contents.data_length; } + + bool readable() const { return access_ & IREE_HAL_MEMORY_ACCESS_READ; } + bool writable() const { return access_ & IREE_HAL_MEMORY_ACCESS_WRITE; } + + private: + iree_hal_memory_access_t access_ = IREE_HAL_MEMORY_ACCESS_NONE; + iree_hal_buffer_mapping_t mapping_; + iree::hal_device_ptr hal_device_ownership_baton_; + friend class storage; +}; + // Array storage backed by an IREE buffer of some form. class SHORTFIN_API storage { public: @@ -29,9 +84,9 @@ class SHORTFIN_API storage { // Allocates host storage, compatible with the given device affinity. // By default, if there are any affinity bits set in the device, then - // the storage will be device visible and have permitted usage for transfers. - // This default policy can be overriden based on device defaults or explicit - // options. + // the storage will be device visible and have permitted usage for + // transfers. This default policy can be overriden based on device defaults + // or explicit options. static storage AllocateHost(local::ScopedDevice &device, iree_device_size_t allocation_size); @@ -53,8 +108,58 @@ class SHORTFIN_API storage { return iree_hal_buffer_byte_length(buffer_.get()); } + // Memory type and access. + iree_hal_memory_type_t memory_type() const; + iree_hal_memory_access_t memory_access() const; + iree_hal_buffer_usage_t buffer_usage() const; + + // Formatted type and access. + std::string formatted_memory_type() const; + std::string formatted_memory_access() const; + std::string formatted_buffer_usage() const; + + // Whether the buffer supports host mappable memory. + bool is_mappable_for_read() const; + bool is_mappable_for_read_write() const; + + // Maps the memory for access from a host pointer using a scoped mapping. + void MapExplicit(mapping &mapping, iree_hal_memory_access_t access); + + // Maps the memory for read/write access, preserving any contents. + mapping MapReadWrite() { + mapping m; + MapExplicit(m, IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE); + return m; + } + + // Maps the memory for discard write. This is used if populating an initial + // buffer. + mapping MapWriteDiscard() { + mapping m; + MapExplicit(m, IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE); + return m; + } + + // Maps the memory for read-only access. + mapping MapRead() { + mapping m; + MapExplicit(m, IREE_HAL_MEMORY_ACCESS_READ); + return m; + } + + const mapping MapRead() const { + mapping m; + const_cast(this)->MapExplicit(m, IREE_HAL_MEMORY_ACCESS_READ); + return m; + } + std::string to_s() const; + // Access raw buffer. This must not be retained apart from the storage for + // any length of time that may extend its lifetime (as the storage keeps + // underlying device references alive as needed). + operator iree_hal_buffer_t *() { return buffer_; } + private: storage(local::ScopedDevice device, iree::hal_buffer_ptr buffer, local::detail::TimelineResource::Ref timeline_resource) @@ -64,18 +169,57 @@ class SHORTFIN_API storage { device_(device), timeline_resource_(std::move(timeline_resource)) {} // TODO(ownership): Since storage is a free-standing object in the system, - // it needs an ownership baton that keeps the device/driver alive. Otherwise, - // it can outlive the backing device and then then crashes on buffer - // deallocation. For now, we stash an RAII hal_device_ptr, which keeps - // everything alive. This isn't quite what we want but keeps us going for now. - // When fixing, add a test that creates an array, destroys the System, and - // then frees the array. + // it needs an ownership baton that keeps the device/driver alive. + // Otherwise, it can outlive the backing device and then then crashes on + // buffer deallocation. For now, we stash an RAII hal_device_ptr, which + // keeps everything alive. This isn't quite what we want but keeps us going + // for now. When fixing, add a test that creates an array, destroys the + // System, and then frees the array. iree::hal_device_ptr hal_device_ownership_baton_; iree::hal_buffer_ptr buffer_; local::ScopedDevice device_; local::detail::TimelineResource::Ref timeline_resource_; }; +// Wraps an untyped mapping, providing typed access. +template +class typed_mapping { + public: + using span_type = std::span; + using const_span_type = std::span; + + typed_mapping(mapping untyped_mapping) + : untyped_mapping_(std::move(untyped_mapping)) {} + typed_mapping(const typed_mapping &) = delete; + typed_mapping &operator=(const typed_mapping &) = delete; + + iree_device_size_t size() const noexcept { + return untyped_mapping_.size() / sizeof(EltTy); + } + bool empty() const noexcept { return size() == 0; } + EltTy *data() noexcept { + return reinterpret_cast(untyped_mapping_.data()); + } + EltTy *data() const noexcept { + return reinterpret_cast(untyped_mapping_.data()); + } + + span_type span() { return span_type(data(), size()); } + const_span_type span() const { return const_span_type(data(), size()); } + + span_type::iterator begin() { return span().begin(); } + span_type::iterator end() { return span().end(); } + + const_span_type::iterator begin() const { return span().begin(); } + const_span_type::iterator end() const { return span().end(); } + + const_span_type::iterator cbegin() const { return span().begin(); } + const_span_type::iterator cend() const { return span().end(); } + + private: + mapping untyped_mapping_; +}; + } // namespace shortfin::array #endif // SHORTFIN_ARRAY_STORAGE_H diff --git a/libshortfin/src/shortfin/array/xtensor_bridge.cc b/libshortfin/src/shortfin/array/xtensor_bridge.cc new file mode 100644 index 000000000..0dc00f9c7 --- /dev/null +++ b/libshortfin/src/shortfin/array/xtensor_bridge.cc @@ -0,0 +1,93 @@ +// Copyright 2024 Advanced Micro Devices, Inc +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "shortfin/array/xtensor_bridge.h" + +#include + +namespace shortfin::array { + +namespace { + +template +class typed_xt_methods final : public poly_xt_methods { + public: + using xt_specific_t = + decltype(xt::adapt(static_cast(nullptr), Dims())); + // Our specific adaptor type must fit within the memory allocation of the + // generic adaptor type. + static_assert(sizeof(xt_specific_t) <= sizeof(xt_generic_t)); + + xt_specific_t &adaptor() { + return *reinterpret_cast(adaptor_storage); + } + + static void concrete_inplace_new(uint8_t *inst_storage, void *array_memory, + size_t array_memory_size, Dims &dims) { + // We rely on the fact that the typed_xt_methods specialization has the + // exact same memory layout as the base class. + static_assert(sizeof(typed_xt_methods) == sizeof(poly_xt_methods)); + + typed_xt_methods *methods = + reinterpret_cast(inst_storage); + new (methods) typed_xt_methods(); + new (methods->adaptor_storage) + xt_specific_t(xt::adapt(static_cast(array_memory), dims)); + } + + void inplace_destruct_this() override { + adaptor().~xt_specific_t(); + this->~typed_xt_methods(); + } + + std::string contents_to_s() override { + std::stringstream out; + out << adaptor(); + return out.str(); + } +}; +} // namespace + +bool poly_xt_methods::inplace_new(uint8_t *inst_storage, DType dtype, + void *array_memory, size_t array_memory_size, + Dims &dims) { +#define POLY_XT_CASE(et, cpp_type) \ + case et: \ + typed_xt_methods::concrete_inplace_new( \ + inst_storage, array_memory, array_memory_size, dims); \ + return true + + switch (static_cast(dtype)) { + // Hot comparisons first. + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_FLOAT_32, float); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_INT_32, int32_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_SINT_32, int32_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_UINT_32, uint32_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_INT_64, int64_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_SINT_64, int64_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_UINT_64, uint64_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_INT_8, int8_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_SINT_8, int8_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_UINT_8, uint8_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_INT_16, int16_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_SINT_16, int16_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_UINT_16, uint16_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_FLOAT_64, double); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_BOOL_8, bool); + // TODO: float16 + // POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_FLOAT_16, TODO); + // TODO: bfloat16 + // POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_BFLOAT_16, TODO); + // TODO: complex64 + // POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64, TODO); + // TODO: complex128 + // POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128, TODO); + } + + return false; +} + +} // namespace shortfin::array diff --git a/libshortfin/src/shortfin/array/xtensor_bridge.h b/libshortfin/src/shortfin/array/xtensor_bridge.h new file mode 100644 index 000000000..a3243e03b --- /dev/null +++ b/libshortfin/src/shortfin/array/xtensor_bridge.h @@ -0,0 +1,160 @@ +// Copyright 2024 Advanced Micro Devices, Inc +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef SHORTFIN_ARRAY_XTENSOR_BRIDGE_H +#define SHORTFIN_ARRAY_XTENSOR_BRIDGE_H + +#include + +#include +#include +#include +#include + +#include "shortfin/array/dims.h" +#include "shortfin/array/dtype.h" +#include "shortfin/array/storage.h" + +namespace shortfin::array { + +// Polymorphic trampoline methods to a backing typed, xarray adaptor. This +// allows xtensor facilities to be used in a dtype agnostic fashion. +class SHORTFIN_API poly_xt_methods { + public: + // Prints the contents of the array. + virtual std::string contents_to_s() = 0; + + protected: + // Since we adapt from a pointer-based container with Dims, just pick one + // as a generic version so that we can reserve space in the class for it. + using xt_generic_t = + decltype(xt::adapt(static_cast(nullptr), Dims())); + + // Placement new an appropriate subclass into the provided storage area, + // which must be sized to hold the base class (subclasses are statically + // asserted to be the same size). The appropriate subclass will also placement + // new an appropriate xtensor adaptor into the adaptor_storage field. It is + // statically asserted that the type specific adaptor will fit into the + // storage area reserved. + // Returns true if an appropriate instance is instantiated. False if no + // implementation for the dtype exists. + static bool inplace_new(uint8_t *inst_storage, DType dtype, + void *array_memory, size_t array_memory_size, + Dims &dims); + + // When instantiated via inplace_new, destorys the instance, calling both + // the type specific adaptor destructor and the subclass destructor. + virtual void inplace_destruct_this() = 0; + + uint8_t adaptor_storage[sizeof(xt_generic_t)]; + + template + friend class poly_xt_mixin; +}; + +// Polymorphic xtensor array mixin. Since xt::array is static on element type, +// this class provides a bridge that will polymorphically manage a specialized +// xarray adaptor for a base_array derived class. +// +// This is designed to use via CRTP on a subclass of base_array. +// +// Access is indirected through a heap allocated poly_xt_methods subclass that +// is initialized on-demand by mapping the device memory and constructing an +// appropriate typed subclass. This is done through two layers of generic +// storage (one contained here for the poly_xt_methods subclass and one +// on that class for the concrete xtensor adaptor it contains). The overhead +// on the base_array instance if the xtensor bridge is not used is one pointer. +// On first use, it is a heap allocation and a switch on dtype. +template +class SHORTFIN_API poly_xt_mixin { + public: + poly_xt_mixin() = default; + // Don't copy the poly instance: if it is needed on the copy, it will be + // re-allocated. + poly_xt_mixin(const poly_xt_mixin &other) {} + + std::optional contents_to_s() { + auto *m = optional_xt_methods(); + if (!m) return {}; + return m->contents_to_s(); + } + + std::optional contents_to_s() const { + return const_cast(this)->contents_to_s(); + } + + // Access (potentially instantiating) the polymorphic xt methods trampoline + // for this array. If no xtensor adaptor can be created or if the memory + // is not accessible to the host, returns nullptr. The returned pointer + // must not outlive the creating array. + poly_xt_methods *optional_xt_methods() { + if (poly_) { + return poly_->methods(); + } + DType dtype = derived_this()->dtype(); + auto inst = std::make_unique(); + // CRTP derived class must provide a memory mapping via its + // map_memory_for_xtensor() method. + // This must be typed as MemoryTy and have data() and size() accessors. + std::optional mapping = derived_this()->map_memory_for_xtensor(); + if (!mapping) { + return nullptr; + } + inst->memory = std::move(*mapping); + void *data = static_cast(inst->memory.data()); + size_t data_size = inst->memory.size(); + if (!poly_xt_methods::inplace_new(inst->methods_storage, dtype, data, + data_size, + derived_this()->shape_container())) { + return nullptr; + } + poly_ = std::move(inst); + return poly_.get()->methods(); + } + + // Accesses (potentially instantiating) the polymorphic xt methods trampoline. + // If it cannot be created, throws a std::logic_error. The returned reference + // must not outlive the creating array. + poly_xt_methods &xt_methods() { + auto m = optional_xt_methods(); + if (!m) { + throw std::logic_error(fmt::format( + "No xtensor specialization registered for dtype {} or storage type", + derived_this()->dtype().name())); + } + return *m; + } + + protected: + ~poly_xt_mixin() { + if (poly_) { + // Need to in-place destruct the adaptor and then the methods itself. + poly_->methods()->inplace_destruct_this(); + } + } + + private: + struct PolyInstance { + MemoryTy memory; + uint8_t methods_storage[sizeof(poly_xt_methods)]; + poly_xt_methods *methods() { + return reinterpret_cast(methods_storage); + } + }; + + const DerivedArrayTy *derived_this() const { + return static_cast(this); + } + DerivedArrayTy *derived_this() { return static_cast(this); } + + // If the polymorphic accessor has been instantiated, it will be constructed + // here. + std::unique_ptr poly_; +}; + +} // namespace shortfin::array + +#endif // SHORTFIN_ARRAY_XTENSOR_BRIDGE_H diff --git a/libshortfin/src/shortfin/local/scope.h b/libshortfin/src/shortfin/local/scope.h index fb6d74fd4..cc6ee8329 100644 --- a/libshortfin/src/shortfin/local/scope.h +++ b/libshortfin/src/shortfin/local/scope.h @@ -28,19 +28,25 @@ class SHORTFIN_API Worker; // needed to do thing with some slice of device queues. class SHORTFIN_API ScopedDevice { public: + ScopedDevice() = default; ScopedDevice(Scope &scope, DeviceAffinity affinity) - : scope_(scope), affinity_(affinity) {} + : scope_(&scope), affinity_(affinity) {} ScopedDevice(Scope &scope, Device *device) - : scope_(scope), affinity_(device) {} + : scope_(&scope), affinity_(device) {} + ScopedDevice(const ScopedDevice &other) + : scope_(other.scope_), affinity_(other.affinity_) {} - Scope &scope() const { return scope_; } + Scope &scope() const { + assert(scope_ && "scope must not be null"); + return *scope_; + } DeviceAffinity affinity() const { return affinity_; } Device *raw_device() const { return affinity_.device(); } std::string to_s() const { return affinity().to_s(); } bool operator==(const ScopedDevice &other) const { - return (&scope_ == &other.scope_) && affinity_ == other.affinity_; + return (scope_ == other.scope_) && affinity_ == other.affinity_; } // Returns a future which will be satisfied when the primary device timeline @@ -49,7 +55,7 @@ class SHORTFIN_API ScopedDevice { CompletionEvent OnSync(bool flush = true); private: - Scope &scope_; + Scope *scope_ = nullptr; DeviceAffinity affinity_; }; diff --git a/libshortfin/src/shortfin/support/CMakeLists.txt b/libshortfin/src/shortfin/support/CMakeLists.txt index ec481d2d5..cbe6df89b 100644 --- a/libshortfin/src/shortfin/support/CMakeLists.txt +++ b/libshortfin/src/shortfin/support/CMakeLists.txt @@ -31,7 +31,7 @@ shortfin_cc_component( ) shortfin_gtest_test( - NAME support_test + NAME shortfin_support_test SRCS # Order is specific: lower level tests before higher level. iree_helpers_test.cc diff --git a/libshortfin/tests/amdgpu_system_test.py b/libshortfin/tests/amdgpu_system_test.py index 4c6d1fae0..4b887ea54 100644 --- a/libshortfin/tests/amdgpu_system_test.py +++ b/libshortfin/tests/amdgpu_system_test.py @@ -8,7 +8,7 @@ @pytest.mark.requires_amd_gpu -def test_create_host_cpu_system(): +def test_create_amd_gpu_system(): from _shortfin import lib as sfl sc = sfl.local.amdgpu.SystemBuilder() @@ -18,3 +18,5 @@ def test_create_host_cpu_system(): print(f" DEVICE: {device_name} = {ls.device(device_name)}") print(ls.devices) + print("Shutting down") + ls.shutdown() diff --git a/libshortfin/tests/array_test.py b/libshortfin/tests/array_test.py index 41cf51aa8..9f53da1c3 100644 --- a/libshortfin/tests/array_test.py +++ b/libshortfin/tests/array_test.py @@ -8,13 +8,15 @@ import pytest import time -from _shortfin import lib as sfl +import shortfin as sf @pytest.fixture def lsys(): - sc = sfl.local.host.CPUSystemBuilder() - return sc.create_system() + sc = sf.host.CPUSystemBuilder() + lsys = sc.create_system() + yield lsys + lsys.shutdown() @pytest.fixture @@ -25,32 +27,79 @@ def scope(lsys): def test_storage(scope): - storage = sfl.array.storage.allocate_device(scope.device(0), 32) + storage = sf.array.storage.allocate_host(scope.device(0), 32) print(storage) - ary = sfl.array.device_array(storage, [2, 4], sfl.array.float32) + ary = sf.array.device_array(storage, [2, 4], sf.array.float32) print(ary) print(ary.shape) assert ary.shape == [2, 4] - assert ary.dtype == sfl.array.float32 + assert ary.dtype == sf.array.float32 + + print("ARY.DEVICE=", ary.device, ary.device.__class__) + print("SCOPE.DEVICE=", scope.device(0)) + print("EQ:", ary.device == scope.device(0)) + assert ary.device == scope.device(0) + # Mapping API contract. + with storage.map(read=True) as m: + assert m.valid + mv = memoryview(m) + assert len(mv) == 32 + assert not m.valid + + storage.data = array.array("f", [1.234534523] * 8) + print("WRITTEN:", ary) + + read_back = array.array("f") + read_back.frombytes(storage.data) + print("READ BACK:", read_back) + + +@pytest.mark.parametrize( + "dtype,code,py_value,expected_repr", + [ + (sf.array.int8, "b", 42, "{{42, 42, 42, 42},\n {42, 42, 42, 42}}"), + (sf.array.int16, "h", 42, "{{42, 42, 42, 42},\n {42, 42, 42, 42}}"), + (sf.array.int32, "i", 42, "{{42, 42, 42, 42},\n {42, 42, 42, 42}}"), + ( + sf.array.float32, + "f", + 42.0, + "{{ 42., 42., 42., 42.},\n { 42., 42., 42., 42.}}", + ), + ( + sf.array.float64, + "d", + 42.0, + "{{ 42., 42., 42., 42.},\n { 42., 42., 42., 42.}}", + ), + ], +) +def test_xtensor_types(scope, dtype, code, py_value, expected_repr): + ary = sf.array.device_array.for_host(scope.device(0), [2, 4], dtype) + ary.storage.data = array.array(code, [py_value] * 8) + r = repr(ary) + print(r) + assert expected_repr in r, f"Expected '{expected_repr}' in '{r}'" + def test_device_array(scope): - ary1 = sfl.array.device_array(scope.device(0), [32, 1, 4], sfl.array.float32) + ary1 = sf.array.device_array(scope.device(0), [32, 1, 4], sf.array.float32) print(ary1) assert ary1.shape == [32, 1, 4] - assert ary1.dtype == sfl.array.float32 + assert ary1.dtype == sf.array.float32 assert scope.device(0) == ary1.device - hary1 = sfl.array.host_array(ary1) + hary1 = sf.array.device_array.for_transfer(ary1) print(hary1) - assert isinstance(hary1, sfl.array.host_array) + assert isinstance(hary1, sf.array.device_array) assert hary1.shape == ary1.shape assert hary1.dtype == ary1.dtype assert hary1.device == ary1.device def test_device_array_fill(scope): - ary1 = sfl.array.device_array(scope.device(0), [32, 1, 4], sfl.array.int32) - ary1.storage.fill(array.array("i", [0])) + ary1 = sf.array.device_array(scope.device(0), [32, 1, 4], sf.array.int32) + ary1.storage.fill(array.array("i", [42])) # TODO: Transfer to host and verify. diff --git a/libshortfin/tests/local_scope_test.py b/libshortfin/tests/local_scope_test.py index de3598711..9f56e7833 100644 --- a/libshortfin/tests/local_scope_test.py +++ b/libshortfin/tests/local_scope_test.py @@ -13,7 +13,9 @@ @pytest.fixture def lsys(): sc = sfl.local.host.CPUSystemBuilder() - return sc.create_system() + ls = sc.create_system() + yield ls + ls.shutdown() @pytest.fixture diff --git a/libshortfin/tests/smoke_test.py b/libshortfin/tests/smoke_test.py deleted file mode 100644 index a066d6eb7..000000000 --- a/libshortfin/tests/smoke_test.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - - -def test_sfl_import(): - from _shortfin import lib as sfl - - sfl.initialize() From b1d09dbf1496949929c9aaf11f3cd9c551975958 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Mon, 26 Aug 2024 06:33:52 +0300 Subject: [PATCH 12/20] Add GEMM to ops and an example of exporting sharded GEMM (#127) --- .../examples/sharding/export_gemm.py | 107 ++++++++++++++++++ sharktank/sharktank/ops/_registry.py | 30 ++++- sharktank/sharktank/ops/default_impls.py | 44 +++++-- sharktank/sharktank/ops/sharded_impls.py | 20 +++- sharktank/sharktank/ops/signatures.py | 41 +++++++ sharktank/sharktank/types/tensors.py | 15 +++ sharktank/tests/ops/ops_test.py | 12 ++ sharktank/tests/ops/sharded_test.py | 19 ++++ 8 files changed, 279 insertions(+), 9 deletions(-) create mode 100644 sharktank/sharktank/examples/sharding/export_gemm.py diff --git a/sharktank/sharktank/examples/sharding/export_gemm.py b/sharktank/sharktank/examples/sharding/export_gemm.py new file mode 100644 index 000000000..7a4322e38 --- /dev/null +++ b/sharktank/sharktank/examples/sharding/export_gemm.py @@ -0,0 +1,107 @@ +import sys +from typing import List +import argparse +import torch +from torch import Tensor +from sharktank import ops +from shark_turbine import aot + + +def export_gemm( + mlir_path: str, + device_count: int, + m: int, + n: int, + k: int, + with_alpha: bool, + with_beta: bool, +): + class GemmModule(torch.nn.Module): + def forward(self, *args, **kwargs): + return ops.gemm(*args, **kwargs) + + a = torch.empty(m, k, dtype=torch.float32) + b = torch.empty(k, n, dtype=torch.float32) + c = torch.empty(m, n, dtype=torch.float32) + sharded_a = ops.reshard_split(a, dim=0, count=device_count) + sharded_b = ops.replicate(b, count=device_count) + sharded_c = ops.reshard_split(c, dim=0, count=device_count) + gemm_module = GemmModule() + kwargs = { + "a": sharded_a, + "b": sharded_b, + "c": sharded_c, + } + # Need to pass alpha and beta not as numbers, but as tensors since + # the IREE FX importer does not support ConstantArgument. + if with_alpha: + kwargs["alpha"] = torch.tensor(2.0, dtype=torch.float32) + if with_alpha: + kwargs["beta"] = torch.tensor(3.0, dtype=torch.float32) + torch_exported = torch.export.export(gemm_module, args=(), kwargs=kwargs) + export_output = aot.export(torch_exported) + export_output.save_mlir(mlir_path) + + +def export_gemm_cli(argv: List[str]): + parser = argparse.ArgumentParser( + description=""" +Export sharded GEMM to MLIR. +alpha * a @ b + beta * c +a is MxK matrix. +b is KxN matrix. +c is MxN matrix. +The sharded/split dimension is M. +a and c will be split across dimension 0 (M). +b will be replicated on all devices. +For n devices the exported function will have signature +(a0, a1, ..., an, b0, b1, ..., bn, c0, c1, ..., cn) -> (r0, r1, ..., rn), +where ai and ci are the respective shards on the i-th device. +bi is equal to b, but on the i-th device. +The caller must place the shards on the expected devices. + +The result is split along dimension M also, +where ri is on the i-th device. + +Support for --with-alpha and --with-beta is under construction. + +Example usage: +python export_gemm.py --device_count=2 --m=10, --k=20, --n=30 \\ + --mlir=sharded-gemm.mlir""", + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "--mlir", help="Path to the exported program.", type=str, required=True + ) + parser.add_argument( + "--device_count", help="Number of shards/devices", type=int, required=True + ) + parser.add_argument("--m", help="M", type=int, default=512) + parser.add_argument("--n", help="N", type=int, default=512) + parser.add_argument("--k", help="K", type=int, default=512) + parser.add_argument( + "--with-alpha", + help="Have alpha as an argument to the function signature", + default=False, + action="store_true", + ) + parser.add_argument( + "--with-beta", + help="Have alpha as an argument to the function signature", + default=False, + action="store_true", + ) + args = parser.parse_args(args=argv[1:]) + export_gemm( + mlir_path=args.mlir, + device_count=args.device_count, + m=args.m, + n=args.n, + k=args.k, + with_alpha=args.with_alpha, + with_beta=args.with_beta, + ) + + +if __name__ == "__main__": + export_gemm_cli(sys.argv) diff --git a/sharktank/sharktank/ops/_registry.py b/sharktank/sharktank/ops/_registry.py index 66fa034f3..c519af75b 100644 --- a/sharktank/sharktank/ops/_registry.py +++ b/sharktank/sharktank/ops/_registry.py @@ -17,8 +17,10 @@ from ..types import PrimitiveTensor, QuantizedTensor __all__ = [ + "AllOfExprs", "AllOfType", "AnyOfType", + "IsOfType", "overridable", "SignatureDispatcher", "BoolTypeExpr", @@ -62,6 +64,29 @@ def __call__(self, *args: type) -> bool: return self._expr(*args) +class AllOfExprs(BoolTypeExpr): + """Returns True if all types match their respective boolean type expression. + + ```python + # True. int == int and str in (float, str). + AllOfExprs(IsOfType(int), IsOfType(float, str))(int, str) + + # False. str is not in (int, float). + AllOfExprs(IsOfType(int), IsOfType(int, float))(int, str) + ``` + """ + + def __init__(self, *exprs: BoolTypeExpr): + self._exprs = exprs + + def expr(*types: type): + if len(types) < len(self._exprs): + return False + return all([e(t) for e, t in zip(self._exprs, types)]) + + super().__init__(expr) + + class AllOfType(BoolTypeExpr): """Returns True if all of the types are from a set of types. @@ -109,6 +134,9 @@ def expr(*types: type): super().__init__(expr) +IsOfType = AllOfType + + class SignatureDispatcher: """Replaces an overridable function with a tensor type base dispatcher. @@ -201,7 +229,7 @@ def _is_type_expr_target( ): if len(override_type_spec) > 1: raise TypeError( - "Override with multiple arguments not allowed when using BoolTypeExpr." + f"Override with multiple arguments not allowed when using BoolTypeExpr. Type spec: {override_type_spec}" ) return True return False diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index ed11eb01a..4ca60cc49 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -7,15 +7,16 @@ # This file contains overrides of the standard ops for normal torch and # generic primitive/quantized types. -from typing import Optional, List, Sequence +from typing import Optional, List, Sequence, Union import torch from torch import Tensor, dtype import torch.nn.functional as F +from numbers import Number -from ..types import PrimitiveTensor, QuantizedTensor -from ..types.tensors import unbox_tensor -from ._registry import AllOfType +from ..types import PrimitiveTensor, QuantizedTensor, InferenceTensor +from ..types.tensors import unbox_tensor, AnyTensor +from ._registry import AllOfType, AllOfExprs, IsOfType from .signatures import * @@ -60,7 +61,6 @@ def conv2d_default( conv2d.override(Tensor, Tensor, Tensor, auto_dequant=True)(conv2d_default) conv2d.override(Tensor, Tensor, auto_dequant=True)(conv2d_default) - # Elementwise @elementwise.override(Tensor) def elementwise_unary(operator, x): @@ -68,10 +68,15 @@ def elementwise_unary(operator, x): return operator(x) -@elementwise.override(Tensor, Tensor) +@elementwise.override( + AllOfExprs( + IsOfType(Tensor, PrimitiveTensor), IsOfType(Tensor, PrimitiveTensor, Number) + ) +) def elementwise_binary(operator, x, y): x = unbox_tensor(x) - y = unbox_tensor(y) + if isinstance(y, PrimitiveTensor): + y = unbox_tensor(y) return operator(x, y) @@ -94,6 +99,31 @@ def equal_default(a, b) -> bool: return torch.equal(unbox_tensor(a), unbox_tensor(b)) +@gemm.override(AllOfType(Tensor, InferenceTensor)) +def gemm( + a: AnyTensor, + b: AnyTensor, + c: Optional[AnyTensor], + alpha: Optional[Union[Number, AnyTensor]], + beta: Optional[Union[Number, AnyTensor]], + transa: bool, + transb: bool, +) -> bool: + if transa: + a = a.T + if transb: + b = b.T + res = matmul(a, b) + if alpha is not None: + res = alpha * res + if c is not None: + if beta is not None: + res = res + beta * c + else: + res = res + c + return res + + # Group norm. @group_norm_affine.override(Tensor, Tensor, Tensor) def group_norm_affine_default(input, weight, bias, *, num_groups, eps): diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py index 4fb85d1ff..b1ef57090 100644 --- a/sharktank/sharktank/ops/sharded_impls.py +++ b/sharktank/sharktank/ops/sharded_impls.py @@ -8,6 +8,7 @@ from torch import Tensor from typing import List, Optional, Sequence import itertools +from numbers import Number from ..types import ( AnyTensor, @@ -248,6 +249,22 @@ def split_elementwise_binary( return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials) +@elementwise.override(SplitPrimitiveTensor, Number) +def elementwise_binary_split_lhs_scalar_rhs( + operator, x: SplitPrimitiveTensor, y: Number +): + pt_xs = [unbox_tensor(pt) for pt in x.shards] + partials = [operator(pt_x, y) for pt_x in pt_xs] + return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials) + + +@elementwise.override(SplitPrimitiveTensor, Tensor) +def elementwise_binary_split_lhs_tensor_rhs( + operator, x: SplitPrimitiveTensor, y: Tensor +): + return elementwise(operator, x, replicate(y, count=x.shard_count)) + + @elementwise.override(ReplicatedTensor, SplitPrimitiveTensor) def elementwise_binary_replicated_lhs_sharder_rhs( operator, x: ReplicatedTensor, y: SplitPrimitiveTensor @@ -264,8 +281,9 @@ def elementwise_binary_replicated_lhs_sharder_rhs( @elementwise.override(SplitPrimitiveTensor, ReplicatedTensor) def elementwise_binary_split_lhs_replicated_rhs( - operator, x: ReplicatedTensor, y: SplitPrimitiveTensor + operator, x: SplitPrimitiveTensor, y: ReplicatedTensor ): + assert len(y.shape) > 0, "0-rank not supported" if x.shard_count != y.shard_count: raise ValueError( f"Operands' number of shards not equal ({x.shard_count} != {y.shard_count})" diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index 7595578ef..07ae56bb1 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -12,6 +12,7 @@ import numbers from torch import Tensor, dtype from ..types import AnyTensor, ShardedTensor, Theta, sharding +from numbers import Number from ._registry import * @@ -22,6 +23,7 @@ "elementwise", "embedding_lookup", "equal", + "gemm", "group_norm_affine", "layer_norm", "interpolate", @@ -210,6 +212,45 @@ def _equal_trampoline(d: SignatureDispatcher, a: AnyTensor, b: AnyTensor): d.fail(tensors) +@overridable +def gemm( + a: AnyTensor, + b: AnyTensor, + c: Optional[AnyTensor] = None, + alpha: Optional[Union[Number, AnyTensor]] = None, + beta: Optional[Union[Number, AnyTensor]] = None, + transa: bool = False, + transb: bool = False, +): + """GEMM as defined by BLAS. + `alpha*a*b + beta*c` + If `c` is None it is the zero-filed tensor. + """ + raise NotImplementedError + + +@gemm.trampoline +def _gemm_trampoline( + d: SignatureDispatcher, + a: AnyTensor, + b: AnyTensor, + c: Optional[AnyTensor] = None, + alpha: Optional[Union[Number, AnyTensor]] = None, + beta: Optional[Union[Number, AnyTensor]] = None, + transa: bool = False, + transb: bool = False, +): + tensors = (a, b, c) + for override in d.find_overrides(tensors): + result = override( + a=a, b=b, c=c, alpha=alpha, beta=beta, transa=transa, transb=transb + ) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + @overridable def group_norm_affine( input: AnyTensor, weight: AnyTensor, bias: AnyTensor, *, num_groups: int, eps: float diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 6acb9e8d2..b48fb1b52 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -284,6 +284,21 @@ def __add__(self, rhs): return elementwise(torch.add, self, rhs) + def __radd__(self, lhs): + # Assumes commutative addition due to torch.elementwise not handling numbers on + # the lhs. + return self.__add__(lhs) + + def __mul__(self, rhs): + from ..ops import elementwise + + return elementwise(torch.mul, self, rhs) + + def __rmul__(self, lhs): + # Assumes commutative multiplication due to torch.elementwise not handling + # numbers on the lhs. + return self.__mul__(lhs) + REGISTERED_INFERENCE_TENSOR_CLASSES: dict[str, Type[InferenceTensor]] = {} diff --git a/sharktank/tests/ops/ops_test.py b/sharktank/tests/ops/ops_test.py index 54469d40a..24a5f91b1 100644 --- a/sharktank/tests/ops/ops_test.py +++ b/sharktank/tests/ops/ops_test.py @@ -90,6 +90,18 @@ def testQuantizedTensorRhs(self): ... +class GemmTest(unittest.TestCase): + def testGemm(self): + a = torch.tensor([[1, 2], [3, 4]]) + b = torch.tensor([[5, 6], [7, 8]]) + c = torch.tensor([[9, 10], [11, 12]]) + alpha = 2 + beta = 3 + expected = alpha * a @ b.T + beta * c + result = ops.gemm(a, b, c, alpha, beta, False, True) + torch.testing.assert_close(result, expected) + + class MatmulTest(unittest.TestCase): def tearDown(self): ops._registry._test_enable_last_op_dispatch(False) diff --git a/sharktank/tests/ops/sharded_test.py b/sharktank/tests/ops/sharded_test.py index a098fd8be..34e5ebca7 100644 --- a/sharktank/tests/ops/sharded_test.py +++ b/sharktank/tests/ops/sharded_test.py @@ -337,6 +337,25 @@ def testNotEqualSharded(self): assert not ops.equal(b_sharded, a_sharded) +class GemmTest(unittest.TestCase): + def testShardedParallelDim(self): + a = torch.rand(4, 3) + b = torch.rand(5, 3) + c = torch.rand(4, 5) + alpha = 2 + beta = 3 + shard_count = 2 + expected = ops.gemm(a, b, c, alpha, beta, False, True) + sharded_a = ops.reshard_split(a, dim=0, count=shard_count) + sharded_c = ops.reshard_split(c, dim=0, count=shard_count) + sharded_result = ops.gemm(sharded_a, b, sharded_c, alpha, beta, False, True) + assert isinstance(sharded_result, SplitPrimitiveTensor) + assert sharded_result.shard_count == 2 + assert sharded_result.shard_dim == 0 + actual = ops.unshard(sharded_result) + torch.testing.assert_close(actual, expected) + + class InterpolateTest(unittest.TestCase): def testInterpolateSplitChannelDim(self): batches = 2 From 40fac2a434f0c2e23b9ccfdc51c14720267e6410 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Mon, 26 Aug 2024 18:07:28 +0200 Subject: [PATCH 13/20] [libshortfin] Build debug/asan (#149) This adds a GitHub action to build a debug/asan version of libshortfin. Shallow cloning of dependencies obtained via FetchContent is reverted as the dependencies cannot be fetched in the CI. --- .../ci_linux_x64_asan-libshortfin.yml | 160 ++++++++++++++++++ libshortfin/CMakeLists.txt | 4 - libshortfin/requirements-tests.txt | 1 + 3 files changed, 161 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/ci_linux_x64_asan-libshortfin.yml diff --git a/.github/workflows/ci_linux_x64_asan-libshortfin.yml b/.github/workflows/ci_linux_x64_asan-libshortfin.yml new file mode 100644 index 000000000..ff10490d3 --- /dev/null +++ b/.github/workflows/ci_linux_x64_asan-libshortfin.yml @@ -0,0 +1,160 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: CI - libshortfin - ASan + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + paths: + - '.github/workflows/ci_linux_x64_asan-libshortfin.yml' + - 'libshortfin/**' + +permissions: + contents: read + +env: + PYENV_ROOT: ${{ github.workspace }}/pyenv + PYENV_REF: 9ecd803bffaffb949fbdd8c70cb086227f6a3202 # v2.4.10 + PYTHON_VER: 3.12.3 + CACHE_ASAN_VER: 1 + CACHE_DEPS_VER: 1 + IREE_SOURCE_DIR: ${{ github.workspace }}/iree + LIBSHORTFIN_DIR: ${{ github.workspace }}/libshortfin/ + + +jobs: + setup-python-asan: + name: Setup Python ASan + runs-on: ubuntu-24.04 + + steps: + - name: Cache Python ASan + id: cache-python-asan + uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2 + with: + path: ${{ env.PYENV_ROOT }} + key: ${{ runner.os }}-python-asan-${{ env.PYENV_REF }}-${{ env.PYTHON_VER }}-v${{ env.CACHE_ASAN_VER }} + lookup-only: 'true' + + - name: Install dependencies + if: steps.cache-python-asan.outputs.cache-hit != 'true' + run: | + sudo apt update + sudo apt install clang lld cmake ninja-build + sudo apt install build-essential libssl-dev zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev + + - name: Checkout pyenv + if: steps.cache-python-asan.outputs.cache-hit != 'true' + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + with: + repository: pyenv/pyenv + ref: ${{ env.PYENV_REF }} + path: ${{ env.PYENV_ROOT }} + + - name: Install pyenv & Python + if: steps.cache-python-asan.outputs.cache-hit != 'true' + run: | + cd ${{ env.PYENV_ROOT }} + src/configure && make -C src + export PATH=${{ env.PYENV_ROOT }}/bin:$PATH && eval "$(pyenv init -)" + CC=clang-18 CXX=clang++-18 LDFLAGS="-lstdc++" PYTHON_CONFIGURE_OPTS="--with-address-sanitizer" pyenv install -v -g ${{ env.PYTHON_VER }} + pyenv global ${{ env.PYTHON_VER }}-debug + + + build-and-test: + name: Build and test libshortfin + needs: [setup-python-asan] + runs-on: ubuntu-24.04 + + steps: + - name: Install dependencies + run: | + sudo apt update + sudo apt install clang lld cmake ninja-build + + - name: Checkout repository + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + with: + submodules: false + + - name: Checkout IREE repo + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + with: + repository: iree-org/iree + path: ${{ env.IREE_SOURCE_DIR }} + submodules: false + + - name: Initalize IREE submodules + run : | + cd ${{ env.IREE_SOURCE_DIR }} + git submodule update --init --depth 1 -- third_party/benchmark + git submodule update --init --depth 1 -- third_party/cpuinfo/ + git submodule update --init --depth 1 -- third_party/flatcc + git submodule update --init --depth 1 -- third_party/googletest + git submodule update --init --depth 1 -- third_party/hip-build-deps/ + + - name: Restore Python dependencies cache + id: cache-python-deps-restore + uses: actions/cache/restore@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2 + with: + path: ${{ env.PYENV_ROOT }} + key: ${{ runner.os }}-python-deps-${{ hashFiles('libshortfin/requirements-tests.txt') }}-v${{ env.CACHE_DEPS_VER }} + + - name: Restore Python ASan cache + id: cache-python-asan + if: steps.cache-python-deps-restore.outputs.cache-hit != 'true' + uses: actions/cache/restore@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2 + with: + path: ${{ env.PYENV_ROOT }} + key: ${{ runner.os }}-python-asan-${{ env.PYENV_REF }}-${{ env.PYTHON_VER }}-v${{ env.CACHE_ASAN_VER }} + + - name: Set path + run: + echo "${{ env.PYENV_ROOT }}/bin" >> $GITHUB_PATH + + - name: Install Python dependencies + if: steps.cache-python-deps-restore.outputs.cache-hit != 'true' + run: | + eval "$(pyenv init -)" + pip install -r ${{ env.LIBSHORTFIN_DIR }}/requirements-tests.txt + + - name: Save Python dependencies cache + if: steps.cache-python-deps-restore.outputs.cache-hit != 'true' + id: cache-python-deps-save + uses: actions/cache/save@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2 + with: + path: ${{ env.PYENV_ROOT }} + key: ${{ steps.cache-python-deps-restore.outputs.cache-primary-key }} + + - name: Build libshortfin + run: | + eval "$(pyenv init -)" + mkdir ${{ env.LIBSHORTFIN_DIR }}/build + cd ${{ env.LIBSHORTFIN_DIR }}/build + cmake -GNinja \ + -DCMAKE_BUILD_TYPE=Debug \ + -DCMAKE_C_COMPILER=clang-18 \ + -DCMAKE_CXX_COMPILER=clang++-18 \ + -DCMAKE_LINKER_TYPE=LLD \ + -DSHORTFIN_BUNDLE_DEPS=ON \ + -DSHORTFIN_IREE_SOURCE_DIR=${{ env.IREE_SOURCE_DIR }} \ + -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \ + -DSHORTFIN_ENABLE_ASAN=ON \ + .. + cmake --build . --target all + pip install -v -e . + + - name: Test libshortfin + run: | + eval "$(pyenv init -)" + cd ${{ env.LIBSHORTFIN_DIR }}/build + cmake --build . --target test + cd ${{ env.LIBSHORTFIN_DIR }} + pytest -m "not requires_amd_gpu" diff --git a/libshortfin/CMakeLists.txt b/libshortfin/CMakeLists.txt index 64ef168e5..082577f24 100644 --- a/libshortfin/CMakeLists.txt +++ b/libshortfin/CMakeLists.txt @@ -72,7 +72,6 @@ if(SHORTFIN_BUNDLE_DEPS) fmt GIT_REPOSITORY https://github.com/fmtlib/fmt.git GIT_TAG e69e5f977d458f2650bb346dadf2ad30c5320281 # 10.2.1 (sync with spdlog) - GIT_SHALLOW TRUE ) ## spdlog @@ -82,7 +81,6 @@ if(SHORTFIN_BUNDLE_DEPS) spdlog GIT_REPOSITORY https://github.com/gabime/spdlog.git GIT_TAG 2d4acf8cc321d7783d8f2e22e17a794c6d0e9450 # v1.14.1 - GIT_SHALLOW TRUE ) ## xtl: required for xtensor @@ -90,7 +88,6 @@ if(SHORTFIN_BUNDLE_DEPS) xtl GIT_REPOSITORY https://github.com/xtensor-stack/xtl.git GIT_TAG a7c1c5444dfc57f76620391af4c94785ff82c8d6 # v0.7.7 - GIT_SHALLOW TRUE ) ## xtensor @@ -98,7 +95,6 @@ if(SHORTFIN_BUNDLE_DEPS) xtensor GIT_REPOSITORY https://github.com/xtensor-stack/xtensor.git GIT_TAG 3634f2ded19e0cf38208c8b86cea9e1d7c8e397d # v0.25.0 - GIT_SHALLOW TRUE ) FetchContent_MakeAvailable(fmt spdlog xtl xtensor) diff --git a/libshortfin/requirements-tests.txt b/libshortfin/requirements-tests.txt index 1049b0412..50bdd9831 100644 --- a/libshortfin/requirements-tests.txt +++ b/libshortfin/requirements-tests.txt @@ -1,3 +1,4 @@ +nanobind==2.0.0 pytest requests fastapi From bade2ab428d51ec57c2157b0e4a02c243dde7d1f Mon Sep 17 00:00:00 2001 From: Avinash Sharma Date: Mon, 26 Aug 2024 13:18:37 -0700 Subject: [PATCH 14/20] Fix bug in llama decode and add tests for direct/paged KVCache (#143) Fixes the decode bug for paged kv cache and adds a couple of tests to compare direct vs. paged kv cache results. TODO: Fix the skipped test for decode that fails for Windows. --------- Signed-off-by: aviator19941 --- sharktank/sharktank/models/llama/llama.py | 2 +- sharktank/tests/models/llama/kv_cache_test.py | 288 ++++++++++++++++++ 2 files changed, 289 insertions(+), 1 deletion(-) create mode 100644 sharktank/tests/models/llama/kv_cache_test.py diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index ea3170122..aaabd3fe6 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -500,7 +500,7 @@ def transact_cache_paged( xv_cache_update, ], transformer_block_index=self.block_index, - seq_positions=start_positions + 1, + seq_positions=start_positions, page_ids=seq_block_ids, ) diff --git a/sharktank/tests/models/llama/kv_cache_test.py b/sharktank/tests/models/llama/kv_cache_test.py new file mode 100644 index 000000000..3953b951b --- /dev/null +++ b/sharktank/tests/models/llama/kv_cache_test.py @@ -0,0 +1,288 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest +import torch +import torch.nn as nn +from sharktank.models.llama.llama import ( + PagedLlamaAttentionBlock, + PagedKVCache, + DirectKVCache, +) +from sharktank.models.llama.testing import * +from sharktank.layers.rotary_embedding import RotaryEmbeddingLayer +from sharktank.layers import causal_llm + + +class KVCacheTest(unittest.TestCase): + def setUp(self): + self.block_count = 5 + self.seq_len = 16 + self.head_count = 32 + self.head_dim = 128 + self.ffn_dim = 11008 + self.head_count_kv = 32 + self.block_seq_stride = 16 + self.rms_epsilon = 1e-5 + self.rope_dimension_count = 128 + self.max_seq_len = 4096 + self.start_positions = torch.tensor([8]) + self.bs = 1 + self.device = "cpu" + self.attention_dtype = torch.float32 + self.attention_block_theta = make_attention_block_theta( + feature_dim=self.head_count * self.head_dim, + ffn_dim=self.ffn_dim, + dtype=self.attention_dtype, + ) + self.paged_kv_cache = PagedKVCache( + transformer_block_count=self.head_count, + attn_head_count=self.head_count, + attn_head_dim=self.head_dim, + cache_partition_count=2, # One for each of K/V. + block_seq_stride=self.block_seq_stride, + device=self.device, + dtype=self.attention_dtype, + ) + self.direct_kv_cache = DirectKVCache( + block_seq_stride=self.block_seq_stride, + transformer_block_count=self.head_count, + attn_head_count=self.head_count, + attn_head_dim=self.head_dim, + seq_length=self.max_seq_len, + device=self.device, + dtype=self.attention_dtype, + ) + self.attention_embedding = RotaryEmbeddingLayer( + rope_dimension_count=self.rope_dimension_count, + max_seqlen=self.max_seq_len, + device=self.device, + use_hf=False, + ) + self.paged_attn_blocks = nn.ModuleList( + [ + PagedLlamaAttentionBlock( + self.attention_block_theta, + block_index=n, + cache=self.paged_kv_cache, + head_count=self.head_count, + head_dim=self.head_dim, + head_count_kv=self.head_count_kv, + rms_epsilon=self.rms_epsilon, + use_hf=False, + ) + for n in range(self.block_count) + ] + ) + self.direct_attn_blocks = nn.ModuleList( + [ + PagedLlamaAttentionBlock( + theta=self.attention_block_theta, + block_index=n, + cache=self.direct_kv_cache, + head_count=self.head_count, + head_dim=self.head_dim, + head_count_kv=self.head_count_kv, + rms_epsilon=self.rms_epsilon, + use_hf=False, + ) + for n in range(self.block_count) + ] + ) + self.paged_cache_state = self.paged_kv_cache.allocate(page_count=128) + self.paged_seq_block_ids = torch.tensor( + [ + [127], + ] + ) + self.direct_cache_state = self.direct_kv_cache.allocate(bs=1) + self.direct_seq_block_ids = torch.tensor( + [ + [0], + ] + ) + self.embedding_batch_mask = self.attention_embedding.compute_batch_mask( + self.start_positions, batch_seq_len=1 + ) + self.model = causal_llm.BaseCausalLMModel( + self.attention_block_theta, context_length=self.max_seq_len + ) + self.prefill_attention_mask = self.model.attention_mask( + self.model.input_mask(self.start_positions, self.seq_len) + ) + + def testDirectAndPagedKVCachePrefill(self): + torch.set_default_dtype(torch.float32) + + paged_input_tensor = make_rand_torch( + (1, self.seq_len, self.head_count * self.head_dim), + dtype=self.attention_dtype, + ) + direct_input_tensor = paged_input_tensor.detach().clone() + # Iterate over paged attention blocks. + for block_idx, paged_block in enumerate(self.paged_attn_blocks): + paged_input_tensor = paged_block( + paged_input_tensor, + embedding=self.attention_embedding, + start_index=0, + attention_mask=self.prefill_attention_mask, + cache_state=self.paged_cache_state, + seq_block_ids=self.paged_seq_block_ids, + ) + # Iterate over direct attention blocks. + for block_idx, direct_block in enumerate(self.direct_attn_blocks): + direct_input_tensor = direct_block( + direct_input_tensor, + embedding=self.attention_embedding, + start_index=0, + attention_mask=self.prefill_attention_mask, + cache_state=self.direct_cache_state, + seq_block_ids=self.direct_seq_block_ids, + ) + page_table = self.paged_kv_cache.unflatten_page_table(self.paged_cache_state) + index_written = self.start_positions.item() + """ + Getting the value of the paged_seq_block_ids, which is the page id we are writing + the K/V cache into. + """ + page_id = self.paged_seq_block_ids[0][0].item() + """ + direct_cache_state is a list of num_transformer_blocks * 2 (one for K and one for V), + so here we index into the first transformer block's keys with self.direct_cache_state[0] + and the first transformer block's values with self.direct_cache_state[1]. Each row + in direct_cache_state is a tensor of [bs, seq_len , attn_heads, attn_dim], so we make sure + the first 8 (start_position) tensors starting at sequence 0 of the seq_len are written to. + """ + updated_direct_cache_state = self.direct_cache_state[0][ + :, :index_written + ].squeeze(0) + """ + paged_cache_state is a list of a single tensor that represents a flattened page table. + Indexing into self.paged_cache_state[0] and unflattening the page table columns to a 6D tensor of: + * transformer block + * cache partition (K or V cache) + * block sequence stride (number of sequence positions per block) + * attention heads + * attention dimensionality + allows us to access the cache partitions for a certain transformer block and sequence in a + certain page_id. For example, page_table[page_id][0, 0, :index_written] lets us access the + first transformer block's K cache for the first 8 (start_positions) tensors starting at + sequence 0. + """ + updated_paged_cache_state = page_table[page_id][0, 0, :index_written] + assert updated_direct_cache_state.shape == updated_paged_cache_state.shape + torch.testing.assert_close( + updated_direct_cache_state, updated_paged_cache_state + ) + + paged_prefill_attn_output = paged_input_tensor + direct_prefill_attn_output = direct_input_tensor + assert paged_prefill_attn_output.shape == direct_prefill_attn_output.shape + torch.testing.assert_close( + paged_prefill_attn_output, direct_prefill_attn_output + ) + + @unittest.skip( + "Bug in Windows decode test for paged_decode_attn_output vs. direct_decode_attn_output" + ) + def testDirectAndPagedKVCacheDecode(self): + torch.set_default_dtype(torch.float32) + self.start_positions.add_(1) + assert self.direct_seq_block_ids.shape[1] == self.paged_seq_block_ids.shape[1] + decode_attention_mask = self.model.decode_attention_mask( + self.model.input_mask( + self.start_positions, self.direct_seq_block_ids.shape[1] * self.seq_len + ) + ) + + token_paged_input_tensor = make_rand_torch( + (1, 1, self.head_count * self.head_dim), dtype=self.attention_dtype + ) + token_direct_input_tensor = token_paged_input_tensor.detach().clone() + + xk_temp = torch.empty( + [ + self.bs, + self.max_seq_len, + self.head_count_kv, + self.head_dim, + ], + dtype=self.attention_dtype, + device=self.device, + ) + xv_temp = torch.empty( + [ + self.bs, + self.max_seq_len, + self.head_count_kv, + self.head_dim, + ], + dtype=self.attention_dtype, + device=self.device, + ) + + # Iterate over paged attention blocks. + for block_idx, paged_block in enumerate(self.paged_attn_blocks): + token_paged_input_tensor = paged_block( + token_paged_input_tensor, + start_positions=self.start_positions, + embedding=self.attention_embedding, + embedding_batch_mask=self.embedding_batch_mask, + attention_mask=decode_attention_mask, + cache_state=self.paged_cache_state, + seq_block_ids=self.paged_seq_block_ids, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) + + # Iterate over direct attention blocks. + for block_idx, direct_block in enumerate(self.direct_attn_blocks): + token_direct_input_tensor = direct_block( + token_direct_input_tensor, + start_positions=self.start_positions, + embedding=self.attention_embedding, + embedding_batch_mask=self.embedding_batch_mask, + attention_mask=decode_attention_mask, + cache_state=self.direct_cache_state, + seq_block_ids=self.direct_seq_block_ids, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) + + page_table = self.paged_kv_cache.unflatten_page_table(self.paged_cache_state) + index_written = self.start_positions.item() + page_id = self.paged_seq_block_ids[0][0].item() + updated_direct_cache_state_keys = self.direct_cache_state[0][ + :, index_written + ].squeeze(0) + updated_paged_cache_state_keys = page_table[page_id][0, 0, index_written] + updated_direct_cache_state_values = self.direct_cache_state[1][ + :, index_written + ].squeeze(0) + updated_paged_cache_state_values = page_table[page_id][0, 1, index_written] + assert ( + updated_direct_cache_state_keys.shape + == updated_paged_cache_state_keys.shape + ) + torch.testing.assert_close( + updated_direct_cache_state_keys, updated_paged_cache_state_keys + ) + assert ( + updated_direct_cache_state_values.shape + == updated_paged_cache_state_values.shape + ) + torch.testing.assert_close( + updated_direct_cache_state_values, updated_paged_cache_state_values + ) + + paged_decode_attn_output = token_paged_input_tensor + direct_decode_attn_output = token_direct_input_tensor + assert paged_decode_attn_output.shape == direct_decode_attn_output.shape + torch.testing.assert_close(paged_decode_attn_output, direct_decode_attn_output) + + +if __name__ == "__main__": + unittest.main() From 46b34c9df417c034ff64a598e7968ec420eb2175 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Mon, 26 Aug 2024 22:47:02 +0200 Subject: [PATCH 15/20] Don't detect and fail on ODR violations (#152) --- .github/workflows/ci_linux_x64_asan-libshortfin.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/ci_linux_x64_asan-libshortfin.yml b/.github/workflows/ci_linux_x64_asan-libshortfin.yml index ff10490d3..e1d6b653e 100644 --- a/.github/workflows/ci_linux_x64_asan-libshortfin.yml +++ b/.github/workflows/ci_linux_x64_asan-libshortfin.yml @@ -134,6 +134,9 @@ jobs: key: ${{ steps.cache-python-deps-restore.outputs.cache-primary-key }} - name: Build libshortfin + env: + # TODO(#151): Don't ignore ODR violations + ASAN_OPTIONS=detect_odr_violation: 0 run: | eval "$(pyenv init -)" mkdir ${{ env.LIBSHORTFIN_DIR }}/build @@ -152,6 +155,8 @@ jobs: pip install -v -e . - name: Test libshortfin + env: + CTEST_OUTPUT_ON_FAILURE: 1 run: | eval "$(pyenv init -)" cd ${{ env.LIBSHORTFIN_DIR }}/build From 810b2d46aa16b7356ef181a821ab64a320ea8f3b Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Mon, 26 Aug 2024 14:19:43 -0700 Subject: [PATCH 16/20] minor tmp path fix (#154) --- sharktank/sharktank/examples/export_paged_llm_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 54b301160..98b3f1bf4 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -31,7 +31,7 @@ def main(): parser.add_argument( "--output-config", help="Output file path for exported config file", - default="/tmp/batch_llama_v1.json", + default="tmp/batch_llama_v1.json", ) parser.add_argument( "--bs", From 7b11628d9fc54f60a09a19ec7ac44cc8d07d36f2 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Mon, 26 Aug 2024 23:53:01 +0200 Subject: [PATCH 17/20] Split running ctest and pytest (#153) --- .github/workflows/ci_linux_x64_asan-libshortfin.yml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci_linux_x64_asan-libshortfin.yml b/.github/workflows/ci_linux_x64_asan-libshortfin.yml index e1d6b653e..5c5310e2a 100644 --- a/.github/workflows/ci_linux_x64_asan-libshortfin.yml +++ b/.github/workflows/ci_linux_x64_asan-libshortfin.yml @@ -154,12 +154,17 @@ jobs: cmake --build . --target all pip install -v -e . - - name: Test libshortfin + - name: Run ctest + if: ${{ !cancelled() }} env: CTEST_OUTPUT_ON_FAILURE: 1 run: | - eval "$(pyenv init -)" cd ${{ env.LIBSHORTFIN_DIR }}/build cmake --build . --target test + + - name: Run pytest + if: ${{ !cancelled() }} + run: | + eval "$(pyenv init -)" cd ${{ env.LIBSHORTFIN_DIR }} pytest -m "not requires_amd_gpu" From 686d9a85dc430e45e599ab84aad76d3fa43ba798 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 26 Aug 2024 16:47:13 -0700 Subject: [PATCH 18/20] [llama] Explicit option to use static tables. (#155) When exporting, it is better to leave table construction to be dynamic and let the compiler move things to initialization time (versus materializing large tables, which can be max_context_length^2). We set static_tables=False unconditionally on export while leaving it True for eager use. Contains a workaround for #156. --- .../sharktank/examples/export_paged_llm_v1.py | 1 + sharktank/sharktank/layers/causal_llm.py | 17 ++++++---- .../sharktank/layers/rotary_embedding.py | 31 ++++++++++++++----- sharktank/sharktank/models/llama/llama.py | 10 ++++++ 4 files changed, 45 insertions(+), 14 deletions(-) diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 98b3f1bf4..78240d614 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -50,6 +50,7 @@ def main(): hp = configs.LlamaHParams.from_gguf_props(dataset.properties) llama_config = LlamaModelConfig(hp) + llama_config.static_tables = False # Rely on the compiler for hoisting tables. llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged" model = PagedLlamaModelV1(dataset.root_theta, llama_config) diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py index 91f700789..d253af617 100644 --- a/sharktank/sharktank/layers/causal_llm.py +++ b/sharktank/sharktank/layers/causal_llm.py @@ -28,7 +28,8 @@ def __init__( theta: Theta, *, context_length: int, - static_context_mask: bool = True, + static_tables: bool = True, + static_context_mask: bool = False, device: Optional[torch.device] = None, activation_dtype: torch.dtype = torch.float32, attention_dtype: torch.dtype = torch.float32, @@ -39,7 +40,7 @@ def __init__( self.attention_dtype = attention_dtype self.context_length = context_length - if static_context_mask: + if static_tables: self.register_buffer( "causal_context_mask", self.generate_causal_context_mask() ) @@ -66,10 +67,12 @@ def _maximally_negative_value(self, dtype): def generate_causal_context_mask(self) -> torch.Tensor: context_length = self.context_length + unary_broadcast_ones = torch.ones([1, 1], dtype=torch.bool, device=self.device) + context_broadcast_ones = unary_broadcast_ones.expand( + context_length, context_length + ) causal_context_mask = torch.triu( - torch.ones( - [context_length, context_length], dtype=torch.bool, device=self.device - ), + context_broadcast_ones, diagonal=1, )[None, None, :, :] return causal_context_mask @@ -114,9 +117,11 @@ def attention_mask( scenarios can benefit from managing this in different ways. """ if causal_context_mask is None: + # Try to use the statically generated. causal_context_mask = self.causal_context_mask if causal_context_mask is None: - causal_context_mask = self._generate_causal_context_mask() + # Fallback to dynamically generated. + causal_context_mask = self.generate_causal_context_mask() # Combine the causal context mask and input mask. dtype = self.attention_dtype diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py index 755392522..18984713d 100644 --- a/sharktank/sharktank/layers/rotary_embedding.py +++ b/sharktank/sharktank/layers/rotary_embedding.py @@ -21,14 +21,29 @@ def __init__( max_seqlen: int, device: Optional[torch.device] = None, use_hf: bool = False, + static_tables: bool = True, ): super().__init__() + # Force static_tables until compiler limitations are solved. + # See https://github.com/nod-ai/sharktank/issues/156 + static_tables = True self.device = device + self.rope_dimension_count = rope_dimension_count + self.max_seqlen = max_seqlen self.use_hf = use_hf - self._table = self._create_rotary_embed_table( - max_seqlen=max_seqlen, - dim=rope_dimension_count, - ) + if static_tables: + self.register_buffer( + "static_rotary_embed_table", self._create_rotary_embed_table() + ) + else: + self.static_rotary_embed_table = None + + @property + def rotary_embed_table(self): + if self.static_rotary_embed_table is None: + return self._create_rotary_embed_table() + else: + return self.static_rotary_embed_table def forward(self, *, xq: torch.Tensor, xk: torch.Tensor, start_index: int): # xq_, xk_ shape: bs, sl, _, dim @@ -80,7 +95,7 @@ def create_ordering_tensor(dim): _, sl, _, dim = xq_.shape # Offset the table based on starting position. - freqs_cis = self._table[start_index : start_index + sl, :] + freqs_cis = self.rotary_embed_table[start_index : start_index + sl, :] assert freqs_cis.shape[-1] == dim assert ( freqs_cis.shape[0] >= sl @@ -139,7 +154,7 @@ def compute_batch_mask( ) + start_positions.unsqueeze(1) # Broadcast lookup to [b, ...]. self.trace_tensor("rope.positions_seq", positions_seq) - freqs_cis = self._table[positions_seq] + freqs_cis = self.rotary_embed_table[positions_seq] # Unsqueeze a unit dim for attention heads. broadcast_freqs_cis = freqs_cis.unsqueeze(2) @@ -167,10 +182,10 @@ def apply_batched_mask( def _create_rotary_embed_table( self, - max_seqlen: int, - dim: int, theta_value: float = 10000.0, ): + dim = self.rope_dimension_count + max_seqlen = self.max_seqlen freqs = 1.0 / ( theta_value ** (torch.arange(0, dim, 2, device=self.device)[: (dim // 2)].float() / dim) diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index aaabd3fe6..984fc6524 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -52,6 +52,14 @@ class LlamaModelConfig: # rotary embedding). use_hf: bool = False + # If true, then the model may pre-initialize certain tables during + # init. This can be better for eager execution but when capturing a program, + # it is often better to preserve the calculation explicitly and rely on + # the compiler to transform it to an initialization time step. This can + # be the difference of many gigabytes of static data being embedded in + # the program and not. + static_tables: bool = True + def create_kv_cache(self) -> BaseKVCache: hp = self.hp if self.kv_cache_type == "direct": @@ -110,6 +118,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): super().__init__( theta, context_length=config.hp.context_length, + static_tables=config.static_tables, device=config.device, activation_dtype=config.activation_dtype, attention_dtype=config.attention_dtype, @@ -131,6 +140,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): max_seqlen=hp.context_length, device=self.device, use_hf=self.use_hf, + static_tables=config.static_tables, ), ) self.add_module( From b80f6569395b37a1806c79002a62d67d1dbfb6da Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 26 Aug 2024 20:59:33 -0700 Subject: [PATCH 19/20] [shortfin] Make TimelineResource hold scopes alive. (#157) Prior to this, TimelineResource held a raw C++ reference to the scope but did not hold a shared_ptr ref to it. The storage class, which holds buffers alive, was keeping the raw IREE buffer and device alive but not the shortfin object hierarchy. On the happy path, this does not cause issues, but during abnormal termination and other race conditions, it can cause lifetime problems. Here we: * Make TimelineResource hold a std::shared_ptr, which transitively keeps everything alive that can be accessed. * Clean up incorrect ordering of fields in key classes, which could cause dependent objects to be destroyed out of order. * Remove cases in iree::object_ptr which could have resulted in incorrect accounting in certain scenarios. * Adds compile-time flags to perform verbose shortfin lifetime logging and application-side IREE reference checks. * Change System::Shutdown() to not clear object references that still may be live (leave that for the destructor). * Correct an issue where drivers could be retained forever, which was then masking another lifetime issue during abnormal termination. * Find a low rate shutdown race triggering use-after-free in the BlockingExecutor and fix it (found by lifetime logging). * Ensure that all tests are ASAN and LSAN clean, even with a new abnormal termination case added (see updates to include buffer copying, which is currently triggering a buffer permission exception, which set this whole triage session in motion). --- .../workflows/ci_linux_x64-libshortfin.yml | 2 +- .../ci_linux_x64_asan-libshortfin.yml | 2 +- libshortfin/CMakeLists.txt | 3 + libshortfin/bindings/python/array_binding.cc | 14 +- .../mobilenet_server/inference_system.py | 20 ++- libshortfin/src/shortfin/array/array.h | 7 +- libshortfin/src/shortfin/array/storage.cc | 47 +++++- libshortfin/src/shortfin/array/storage.h | 33 ++-- libshortfin/src/shortfin/local/scheduler.cc | 31 +++- libshortfin/src/shortfin/local/scheduler.h | 51 ++++-- libshortfin/src/shortfin/local/scope.cc | 16 +- libshortfin/src/shortfin/local/scope.h | 17 +- libshortfin/src/shortfin/local/system.cc | 32 ++-- .../src/shortfin/support/blocking_executor.cc | 13 +- .../support/blocking_executor_test.cc | 15 +- .../src/shortfin/support/iree_concurrency.h | 11 +- .../src/shortfin/support/iree_helpers.cc | 75 ++++++++- .../src/shortfin/support/iree_helpers.h | 154 +++++++++++++++--- .../src/shortfin/support/iree_helpers_test.cc | 1 + libshortfin/src/shortfin/support/logging.h | 20 +++ 20 files changed, 443 insertions(+), 121 deletions(-) diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml index 20f944c5b..babcf0245 100644 --- a/.github/workflows/ci_linux_x64-libshortfin.yml +++ b/.github/workflows/ci_linux_x64-libshortfin.yml @@ -103,7 +103,7 @@ jobs: - name: Test libshortfin (full) run: | cd ${{ env.LIBSHORTFIN_DIR }}/build - cmake --build . --target test + ctest --timeout 30 --output-on-failure cd ${{ env.LIBSHORTFIN_DIR }} pytest -s -v -m "not requires_amd_gpu" diff --git a/.github/workflows/ci_linux_x64_asan-libshortfin.yml b/.github/workflows/ci_linux_x64_asan-libshortfin.yml index 5c5310e2a..14aa26bda 100644 --- a/.github/workflows/ci_linux_x64_asan-libshortfin.yml +++ b/.github/workflows/ci_linux_x64_asan-libshortfin.yml @@ -160,7 +160,7 @@ jobs: CTEST_OUTPUT_ON_FAILURE: 1 run: | cd ${{ env.LIBSHORTFIN_DIR }}/build - cmake --build . --target test + ctest --timeout 30 --output-on-failure - name: Run pytest if: ${{ !cancelled() }} diff --git a/libshortfin/CMakeLists.txt b/libshortfin/CMakeLists.txt index 082577f24..c5a2f0f6a 100644 --- a/libshortfin/CMakeLists.txt +++ b/libshortfin/CMakeLists.txt @@ -46,6 +46,9 @@ option(SHORTFIN_ENABLE_ASAN "Enable ASAN" OFF) if(SHORTFIN_ENABLE_ASAN) add_compile_options(-fsanitize=address) add_link_options(-fsanitize=address) + + # Enable more ASAN checks. + add_compile_definitions(IREE_SANITIZER_ADDRESS) endif() option(SHORTFIN_SYSTEMS_AMDGPU "Builds for AMD GPU systems" ON) diff --git a/libshortfin/bindings/python/array_binding.cc b/libshortfin/bindings/python/array_binding.cc index fc4694107..9858c2350 100644 --- a/libshortfin/bindings/python/array_binding.cc +++ b/libshortfin/bindings/python/array_binding.cc @@ -123,6 +123,7 @@ void BindArray(py::module_ &m) { PyBufferReleaser py_view_releaser(py_view); self.Fill(py_view.buf, py_view.len); }) + .def("copy_from", [](storage &self, storage &src) { self.CopyFrom(src); }) .def( "map", [](storage &self, bool read, bool write, bool discard) { @@ -232,13 +233,12 @@ void BindArray(py::module_ &m) { py::type(), /*keep_alive=*/device.scope(), device_array::for_host(device, shape, dtype)); }) - .def_static("for_transfer", - [](device_array &existing) { - return custom_new_keep_alive( - py::type(), - /*keep_alive=*/existing.device().scope(), - device_array::for_transfer(existing)); - }) + .def("for_transfer", + [](device_array &self) { + return custom_new_keep_alive( + py::type(), + /*keep_alive=*/self.device().scope(), self.for_transfer()); + }) .def_prop_ro("device", &device_array::device, py::rv_policy::reference_internal) .def_prop_ro("storage", &device_array::storage, diff --git a/libshortfin/examples/python/mobilenet_server/inference_system.py b/libshortfin/examples/python/mobilenet_server/inference_system.py index e2be35910..8ae7773db 100644 --- a/libshortfin/examples/python/mobilenet_server/inference_system.py +++ b/libshortfin/examples/python/mobilenet_server/inference_system.py @@ -12,7 +12,7 @@ import shortfin as sf import shortfin.array as sfnp -MAX_BATCH = 8 +MAX_BATCH = 1 class InferenceRequest(sf.Message): @@ -27,18 +27,23 @@ def __init__(self, program, request_queue, **kwargs): self.program = program self.request_reader = request_queue.reader() self.device = self.scope.device(0) - self.host_staging = sfnp.host_array( - self.device, [MAX_BATCH, 3, 224, 224], sfnp.float32 - ) self.device_input = sfnp.device_array( self.device, [MAX_BATCH, 3, 224, 224], sfnp.float32 ) + self.host_staging = self.device_input.for_transfer() async def run(self): print(f"Inference process: {self.pid}") while request := await self.request_reader(): print(f"[{self.pid}] Got request {request}") - # self.host_staging.data = self.raw_image_data + # TODO: Should really be taking a slice and writing that. For now, + # just writing to the backing storage is the best we have API + # support for. Generally, APIs on storage should be mirrored onto + # the array. + self.host_staging.storage.data = request.raw_image_data + print(self.host_staging) + self.device_input.storage.copy_from(self.host_staging.storage) + print(self.device_input) class Main: @@ -95,7 +100,10 @@ def client(): # Dumb way to prepare some data to feed [1, 3, 224, 224] f32. import array - dummy_data = array.array("f", [0.2] * (3 * 224 * 224)) + dummy_data = array.array( + "f", ([0.2] * (224 * 224)) + ([0.4] * (224 * 224)) + ([-0.2] * (224 * 224)) + ) + # dummy_data = array.array("f", [0.2] * (3 * 224 * 224)) message = InferenceRequest(dummy_data) writer(message) diff --git a/libshortfin/src/shortfin/array/array.h b/libshortfin/src/shortfin/array/array.h index 31deb665a..c3ab6e302 100644 --- a/libshortfin/src/shortfin/array/array.h +++ b/libshortfin/src/shortfin/array/array.h @@ -80,10 +80,9 @@ class SHORTFIN_API device_array shape, dtype); } - // Allocates a host array for transfer to/from the given device array. - static device_array for_transfer(device_array &with_device_array) { - return for_host(with_device_array.storage().device(), - with_device_array.shape(), with_device_array.dtype()); + // Allocates a host array for transfer to/from this array. + device_array for_transfer() { + return for_host(storage().device(), shape(), dtype()); } // Untyped access to the backing data. The array must be mappable. Specific diff --git a/libshortfin/src/shortfin/array/storage.cc b/libshortfin/src/shortfin/array/storage.cc index 6554f3e74..fa9e0f4b8 100644 --- a/libshortfin/src/shortfin/array/storage.cc +++ b/libshortfin/src/shortfin/array/storage.cc @@ -26,6 +26,15 @@ void ThrowIllegalDeviceAffinity(Device *first, Device *second) { } } // namespace detail +storage::storage(local::ScopedDevice device, iree::hal_buffer_ptr buffer, + local::detail::TimelineResource::Ref timeline_resource) + : timeline_resource_(std::move(timeline_resource)), + buffer_(std::move(buffer)), + device_(device) { + logging::construct("array::storage", this); +} +storage::~storage() { logging::destruct("array::storage", this); } + storage storage::AllocateDevice(ScopedDevice &device, iree_device_size_t allocation_size) { if (!device.raw_device()) { @@ -103,7 +112,28 @@ void storage::Fill(const void *pattern, iree_host_size_t pattern_length) { } void storage::CopyFrom(storage &source_storage) { - throw std::logic_error("CopyFrom NYI"); + device_.scope().scheduler().AppendCommandBuffer( + device_, TransactionType::TRANSFER, [&](Account &account) { + // Must depend on the source's mutation dependencies to avoid + // read-before-write hazard. + account.active_deps_extend( + source_storage.timeline_resource_->mutation_barrier()); + // And depend on our own use and mutations dependencies. + account.active_deps_extend(timeline_resource_->use_barrier()); + account.active_deps_extend(timeline_resource_->mutation_barrier()); + + SHORTFIN_THROW_IF_ERROR(iree_hal_command_buffer_copy_buffer( + account.active_command_buffer(), + /*source_ref=*/ + iree_hal_make_buffer_ref(source_storage.buffer_, 0, byte_length()), + /*target_ref=*/ + iree_hal_make_buffer_ref(buffer_, 0, byte_length()))); + + // And move our own mutation barrier to the current pending timeline + // value. + timeline_resource_->set_mutation_barrier( + account.timeline_sem(), account.timeline_idle_timepoint()); + }); } bool storage::is_mappable_for_read() const { @@ -127,8 +157,7 @@ void storage::MapExplicit(mapping &mapping, iree_hal_memory_access_t access) { buffer_, IREE_HAL_MAPPING_MODE_SCOPED, access, /*byte_offset=*/0, byte_length(), &mapping.mapping_)); mapping.access_ = access; - mapping.hal_device_ownership_baton_ = - iree::hal_device_ptr::borrow_reference(hal_device_ownership_baton_); + mapping.timeline_resource_ = timeline_resource_; } iree_hal_memory_type_t storage::memory_type() const { @@ -169,16 +198,22 @@ std::string storage::to_s() const { // mapping // -------------------------------------------------------------------------- // -mapping::mapping() { std::memset(&mapping_, 0, sizeof(mapping_)); } +mapping::mapping() { + logging::construct("array::mapping", this); + std::memset(&mapping_, 0, sizeof(mapping_)); +} -mapping::~mapping() noexcept { reset(); } +mapping::~mapping() noexcept { + logging::destruct("array::mapping", this); + reset(); +} void mapping::reset() noexcept { if (*this) { // Crash the process on failure to unmap. We don't have a good mitigation, IREE_CHECK_OK(iree_hal_buffer_unmap_range(&mapping_)); access_ = IREE_HAL_MEMORY_ACCESS_NONE; - hal_device_ownership_baton_.reset(); + timeline_resource_.reset(); } } diff --git a/libshortfin/src/shortfin/array/storage.h b/libshortfin/src/shortfin/array/storage.h index 36f117cb4..0db73d28f 100644 --- a/libshortfin/src/shortfin/array/storage.h +++ b/libshortfin/src/shortfin/array/storage.h @@ -23,18 +23,17 @@ class SHORTFIN_API mapping { mapping(const mapping &) = delete; mapping &operator=(const mapping &) = delete; mapping &operator=(mapping &&other) { + timeline_resource_ = std::move(other.timeline_resource_); access_ = other.access_; mapping_ = other.mapping_; - hal_device_ownership_baton_ = std::move(other.hal_device_ownership_baton_); other.access_ = IREE_HAL_MEMORY_ACCESS_NONE; std::memset(&other.mapping_, 0, sizeof(other.mapping_)); return *this; } mapping(mapping &&other) - : access_(other.access_), - mapping_(other.mapping_), - hal_device_ownership_baton_( - std::move(other.hal_device_ownership_baton_)) { + : timeline_resource_(std::move(other.timeline_resource_)), + access_(other.access_), + mapping_(other.mapping_) { other.access_ = IREE_HAL_MEMORY_ACCESS_NONE; std::memset(&other.mapping_, 0, sizeof(other.mapping_)); } @@ -63,15 +62,17 @@ class SHORTFIN_API mapping { bool writable() const { return access_ & IREE_HAL_MEMORY_ACCESS_WRITE; } private: + // See note on storage::timeline_resource_. Must be declared first. + local::detail::TimelineResource::Ref timeline_resource_; iree_hal_memory_access_t access_ = IREE_HAL_MEMORY_ACCESS_NONE; iree_hal_buffer_mapping_t mapping_; - iree::hal_device_ptr hal_device_ownership_baton_; friend class storage; }; // Array storage backed by an IREE buffer of some form. class SHORTFIN_API storage { public: + ~storage(); local::ScopedDevice &device() { return device_; } local::Scope &scope() { return device_.scope(); } const local::ScopedDevice &device() const { return device_; } @@ -162,23 +163,13 @@ class SHORTFIN_API storage { private: storage(local::ScopedDevice device, iree::hal_buffer_ptr buffer, - local::detail::TimelineResource::Ref timeline_resource) - : hal_device_ownership_baton_(iree::hal_device_ptr::borrow_reference( - device.raw_device()->hal_device())), - buffer_(std::move(buffer)), - device_(device), - timeline_resource_(std::move(timeline_resource)) {} - // TODO(ownership): Since storage is a free-standing object in the system, - // it needs an ownership baton that keeps the device/driver alive. - // Otherwise, it can outlive the backing device and then then crashes on - // buffer deallocation. For now, we stash an RAII hal_device_ptr, which - // keeps everything alive. This isn't quite what we want but keeps us going - // for now. When fixing, add a test that creates an array, destroys the - // System, and then frees the array. - iree::hal_device_ptr hal_device_ownership_baton_; + local::detail::TimelineResource::Ref timeline_resource); + // The timeline resource holds the back reference to the owning scope, + // which keeps all devices alive. Buffers must be destroyed before devices, + // so this must be declared first. + local::detail::TimelineResource::Ref timeline_resource_; iree::hal_buffer_ptr buffer_; local::ScopedDevice device_; - local::detail::TimelineResource::Ref timeline_resource_; }; // Wraps an untyped mapping, providing typed access. diff --git a/libshortfin/src/shortfin/local/scheduler.cc b/libshortfin/src/shortfin/local/scheduler.cc index 64e4247e6..c5a9fc062 100644 --- a/libshortfin/src/shortfin/local/scheduler.cc +++ b/libshortfin/src/shortfin/local/scheduler.cc @@ -30,6 +30,9 @@ void Account::Initialize() { void Account::Reset() { active_tx_type_ = TransactionType::NONE; + // if (active_command_buffer_) { + // iree_hal_command_buffer_end(active_command_buffer_); + // } active_command_buffer_.reset(); } @@ -67,10 +70,17 @@ CompletionEvent Account::OnSync() { // TimelineResource // -------------------------------------------------------------------------- // -TimelineResource::TimelineResource(iree_allocator_t host_allocator, - size_t semaphore_capacity) { - SHORTFIN_THROW_IF_ERROR(iree_hal_fence_create( - semaphore_capacity, host_allocator, use_barrier_fence_.for_output())); +TimelineResource::TimelineResource(std::shared_ptr scope, + size_t semaphore_capacity) + : scope_(std::move(scope)) { + logging::construct("TimelineResource", this); + SHORTFIN_THROW_IF_ERROR( + iree_hal_fence_create(semaphore_capacity, scope_->host_allocator(), + use_barrier_fence_.for_output())); +} + +TimelineResource::~TimelineResource() { + logging::destruct("TimelineResource", this); } void TimelineResource::use_barrier_insert(iree_hal_semaphore_t *sem, @@ -83,6 +93,19 @@ void TimelineResource::use_barrier_insert(iree_hal_semaphore_t *sem, // Scheduler // -------------------------------------------------------------------------- // +Scheduler::Scheduler(System &system) : system_(system) { + logging::construct("Scheduler", this); +} + +Scheduler::~Scheduler() { + logging::destruct("Scheduler", this); + + // Explicitly reset account state prior to implicit destruction. + for (auto &account : accounts_) { + account.Reset(); + } +} + void Scheduler::Initialize(std::span devices) { for (Device *device : devices) { accounts_.emplace_back(*this, device); diff --git a/libshortfin/src/shortfin/local/scheduler.h b/libshortfin/src/shortfin/local/scheduler.h index 057bfbd9f..2f606ced3 100644 --- a/libshortfin/src/shortfin/local/scheduler.h +++ b/libshortfin/src/shortfin/local/scheduler.h @@ -83,13 +83,35 @@ class SHORTFIN_API TimelineResource { Ref() : res_(nullptr) {} explicit Ref(TimelineResource *res) : res_(res) { res_->Retain(); } Ref(const Ref &other) : res_(other.res_) { res_->Retain(); } - void operator=(const Ref &other) = delete; - Ref(Ref &&other) : res_(other.res_) { other.res_ = nullptr; } - ~Ref() { - if (res_) res_->Release(); + Ref &operator=(const Ref &other) { + if (other.res_ != res_) { + reset(); + if (other.res_) { + other.res_->Retain(); + res_ = other.res_; + } + } + return *this; + } + Ref &operator=(Ref &&other) { + if (other.res_ != res_) { + reset(); + res_ = other.res_; + other.res_ = nullptr; + } + return *this; } + Ref(Ref &&other) : res_(other.res_) { other.res_ = nullptr; } + ~Ref() { reset(); } TimelineResource *operator->() { return res_; } + void reset() { + if (res_) { + res_->Release(); + res_ = nullptr; + } + } + private: TimelineResource *res_; }; @@ -121,13 +143,18 @@ class SHORTFIN_API TimelineResource { } private: - TimelineResource(iree_allocator_t host_allocator, size_t semaphore_capacity); + TimelineResource(std::shared_ptr scope, size_t semaphore_capacity); + ~TimelineResource(); void Retain() { refcnt_++; } void Release() { if (--refcnt_ == 0) delete this; } int refcnt_ = 0; + + // Back reference to the owning scope. + std::shared_ptr scope_; + // Non-owning mutation barrier semaphore and timepoint. The fact that this // is a single semaphore is an implementation detail that may be generalized // in the future should it be necessary to track multiple write sources. @@ -171,11 +198,13 @@ class SHORTFIN_API Account { void Initialize(); void Reset(); Scheduler &scheduler_; + iree::hal_semaphore_ptr sem_; + iree::hal_fence_ptr active_deps_; + iree::hal_command_buffer_ptr active_command_buffer_; + Device *device_; iree_hal_device_t *hal_device_; TransactionType active_tx_type_ = TransactionType::NONE; - iree::hal_fence_ptr active_deps_; - iree::hal_command_buffer_ptr active_command_buffer_; iree_hal_queue_affinity_t active_queue_affinity_bits_; // Timepoint at which this device is considered idle, inclusive of any @@ -193,14 +222,14 @@ class SHORTFIN_API Account { // an eventual submission would submit a duplicate timepoint). This // timepoint is only valid for the local sem_. uint64_t idle_timepoint_ = 0; - iree::hal_semaphore_ptr sem_; friend class Scheduler; }; // Handles scheduling state for a scope. class SHORTFIN_API Scheduler { public: - Scheduler(System &system) : system_(system) {} + Scheduler(System &system); + ~Scheduler(); TransactionMode transaction_mode() const { return tx_mode_; } @@ -224,9 +253,9 @@ class SHORTFIN_API Scheduler { // Gets a fresh TimelineResource which can be used for tracking resource // read/write and setting barriers. Note that these are all allocated fresh // on each call today but may be pooled in the future. - TimelineResource::Ref NewTimelineResource(iree_allocator_t host_allocator) { + TimelineResource::Ref NewTimelineResource(std::shared_ptr scope) { return TimelineResource::Ref( - new TimelineResource(host_allocator, semaphore_count_)); + new TimelineResource(std::move(scope), semaphore_count_)); } System &system() { return system_; } diff --git a/libshortfin/src/shortfin/local/scope.cc b/libshortfin/src/shortfin/local/scope.cc index f0eb9ca77..39784f196 100644 --- a/libshortfin/src/shortfin/local/scope.cc +++ b/libshortfin/src/shortfin/local/scope.cc @@ -21,10 +21,11 @@ namespace shortfin::local { Scope::Scope(std::shared_ptr system, Worker &worker, std::span> devices) - : host_allocator_(system->host_allocator()), - scheduler_(*system), - system_(std::move(system)), + : system_(std::move(system)), + host_allocator_(system_->host_allocator()), + scheduler_(*system_), worker_(worker) { + logging::construct("Scope", this); for (auto &it : devices) { AddDevice(it.first, it.second); } @@ -33,17 +34,18 @@ Scope::Scope(std::shared_ptr system, Worker &worker, Scope::Scope(std::shared_ptr system, Worker &worker, std::span devices) - : host_allocator_(system->host_allocator()), - scheduler_(*system), - system_(std::move(system)), + : system_(std::move(system)), + host_allocator_(system_->host_allocator()), + scheduler_(*system_), worker_(worker) { + logging::construct("Scope", this); for (auto *device : devices) { AddDevice(device->address().logical_device_class, device); } Initialize(); } -Scope::~Scope() = default; +Scope::~Scope() { logging::destruct("Scope", this); } std::string Scope::to_s() const { return fmt::format("Scope(worker='{}', devices=[{}])", worker_.name(), diff --git a/libshortfin/src/shortfin/local/scope.h b/libshortfin/src/shortfin/local/scope.h index cc6ee8329..0cb566b89 100644 --- a/libshortfin/src/shortfin/local/scope.h +++ b/libshortfin/src/shortfin/local/scope.h @@ -91,6 +91,9 @@ class SHORTFIN_API Scope : public std::enable_shared_from_this { // All scopes are created as shared pointers. std::shared_ptr shared_ptr() { return shared_from_this(); } + // The host allocator. + iree_allocator_t host_allocator() { return host_allocator_; } + // The worker that this scope is bound to. Worker &worker() { return worker_; } @@ -126,7 +129,7 @@ class SHORTFIN_API Scope : public std::enable_shared_from_this { } detail::Scheduler &scheduler() { return scheduler_; } detail::TimelineResource::Ref NewTimelineResource() { - return scheduler().NewTimelineResource(host_allocator_); + return scheduler().NewTimelineResource(shared_ptr()); } // Loads a program from a list of modules onto the devices managed by this @@ -141,19 +144,19 @@ class SHORTFIN_API Scope : public std::enable_shared_from_this { void AddDevice(std::string_view device_class, Device *device); void Initialize(); // Called after all devices are added. - iree_allocator_t host_allocator_; + // Back reference to owning system. + std::shared_ptr system_; string_interner interner_; + iree_allocator_t host_allocator_; + detail::Scheduler scheduler_; + Worker &worker_; + // Map of `` to the count of that class contained. std::unordered_map device_class_count_; // Ordered devices. std::vector devices_; // Map of `` to Device. std::unordered_map named_devices_; - detail::Scheduler scheduler_; - - // Back reference to owning system. - std::shared_ptr system_; - Worker &worker_; }; } // namespace shortfin::local diff --git a/libshortfin/src/shortfin/local/system.cc b/libshortfin/src/shortfin/local/system.cc index 28c8c9654..2eaf3eaf7 100644 --- a/libshortfin/src/shortfin/local/system.cc +++ b/libshortfin/src/shortfin/local/system.cc @@ -19,6 +19,7 @@ namespace shortfin::local { System::System(iree_allocator_t host_allocator) : host_allocator_(host_allocator) { + logging::construct("System", this); SHORTFIN_THROW_IF_ERROR(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, host_allocator_, vm_instance_.for_output())); @@ -27,6 +28,7 @@ System::System(iree_allocator_t host_allocator) } System::~System() { + logging::destruct("System", this); bool needs_shutdown = false; { iree::slim_mutex_lock_guard guard(lock_); @@ -40,6 +42,21 @@ System::~System() { "explicitly for maximum stability."); Shutdown(); } + + // Orderly destruction of heavy-weight objects. + // Shutdown order is important so we don't leave it to field ordering. + vm_instance_.reset(); + + // Devices. + devices_.clear(); + named_devices_.clear(); + retained_devices_.clear(); + + // HAL drivers. + hal_drivers_.clear(); + + // If support for logging refs was compiled in, report now. + iree::detail::LogLiveRefs(); } void System::Shutdown() { @@ -63,20 +80,7 @@ void System::Shutdown() { } } blocking_executor_.Kill(); - local_workers.clear(); - - // Orderly destruction of heavy-weight objects. - // Shutdown order is important so we don't leave it to field ordering. - vm_instance_.reset(); - - // Devices. - devices_.clear(); - named_devices_.clear(); - retained_devices_.clear(); - - // HAL drivers. - hal_drivers_.clear(); } std::shared_ptr System::CreateScope(Worker &worker, @@ -180,7 +184,7 @@ void System::InitializeHalDriver(std::string_view moniker, throw std::logic_error(fmt::format( "Cannot register multiple hal drivers with moniker '{}'", moniker)); } - slot.reset(driver.release()); + slot = std::move(driver); } void System::InitializeHalDevice(std::unique_ptr device) { diff --git a/libshortfin/src/shortfin/support/blocking_executor.cc b/libshortfin/src/shortfin/support/blocking_executor.cc index fc739ec0c..fde3cc593 100644 --- a/libshortfin/src/shortfin/support/blocking_executor.cc +++ b/libshortfin/src/shortfin/support/blocking_executor.cc @@ -59,9 +59,18 @@ void BlockingExecutor::Kill(bool wait, iree_timeout_t warn_timeout) { iree::slim_mutex_lock_guard g(control_mu_); last_live_thread_count = live_thread_count_; total_thread_count = created_thread_count_; + // If transitioned to 0 live threads, there is a short period of time + // that can exist between the scan of the free list above and a task + // getting scheduled. Therefore, the first time we hit this condition, + // enter the inhibited state, which denies further scheduling. Then + // the next time we encounter no live threads, that will be a true + // count. if (live_thread_count_ == 0) { - inhibit_ = true; - break; + if (inhibit_) { + break; + } else { + inhibit_ = true; + } } } diff --git a/libshortfin/src/shortfin/support/blocking_executor_test.cc b/libshortfin/src/shortfin/support/blocking_executor_test.cc index 78f99cf4a..92a9b31f5 100644 --- a/libshortfin/src/shortfin/support/blocking_executor_test.cc +++ b/libshortfin/src/shortfin/support/blocking_executor_test.cc @@ -13,7 +13,13 @@ namespace shortfin { -TEST(BlockingExecutor, concurrent_tasks) { +class BlockingExecutorTest : public testing::Test { + protected: + void SetUp() override {} + void TearDown() override { iree::detail::LogLiveRefs(); } +}; + +TEST_F(BlockingExecutorTest, concurrent_tasks) { { std::atomic tasks_run{0}; @@ -33,7 +39,7 @@ TEST(BlockingExecutor, concurrent_tasks) { } } -TEST(BlockingExecutor, inhibit_when_shutdown) { +TEST_F(BlockingExecutorTest, inhibit_when_shutdown) { { std::atomic tasks_run{0}; @@ -46,6 +52,7 @@ TEST(BlockingExecutor, inhibit_when_shutdown) { } executor.Kill(/*wait=*/true); + logging::info("Killed"); // New work should be inhibited. try { @@ -57,7 +64,7 @@ TEST(BlockingExecutor, inhibit_when_shutdown) { } } -TEST(BlockingExecutor, warn_deadline) { +TEST_F(BlockingExecutorTest, warn_deadline) { { std::atomic tasks_run{0}; @@ -75,7 +82,7 @@ TEST(BlockingExecutor, warn_deadline) { } } -TEST(BlockingExecutor, threads_recycle) { +TEST_F(BlockingExecutorTest, threads_recycle) { { std::atomic tasks_run{0}; diff --git a/libshortfin/src/shortfin/support/iree_concurrency.h b/libshortfin/src/shortfin/support/iree_concurrency.h index 6ccd1792e..28ef1e99b 100644 --- a/libshortfin/src/shortfin/support/iree_concurrency.h +++ b/libshortfin/src/shortfin/support/iree_concurrency.h @@ -18,8 +18,15 @@ namespace shortfin::iree { namespace detail { struct thread_ptr_helper { - static void retain(iree_thread_t *obj) { iree_thread_retain(obj); } - static void release(iree_thread_t *obj) { iree_thread_release(obj); } + static void steal(iree_thread_t *obj) { LogIREESteal("iree_thread_t", obj); } + static void retain(iree_thread_t *obj) { + LogIREERetain("iree_thread_t", obj); + iree_thread_retain(obj); + } + static void release(iree_thread_t *obj) { + LogIREERelease("iree_thread_t", obj); + iree_thread_release(obj); + } }; }; // namespace detail diff --git a/libshortfin/src/shortfin/support/iree_helpers.cc b/libshortfin/src/shortfin/support/iree_helpers.cc index 8344377b4..d518e99c3 100644 --- a/libshortfin/src/shortfin/support/iree_helpers.cc +++ b/libshortfin/src/shortfin/support/iree_helpers.cc @@ -6,8 +6,81 @@ #include "shortfin/support/iree_helpers.h" +#include + +#include +#include + +#include "shortfin/support/iree_concurrency.h" +#include "shortfin/support/logging.h" + namespace shortfin::iree { +namespace detail { + +#if SHORTFIN_IREE_LOG_RC + +slim_mutex log_mutex; +std::unordered_map app_ref_counts; + +void LogIREERetain(const char *type_name, void *ptr) { + slim_mutex_lock_guard g(log_mutex); + std::string key = fmt::format("{}({})", type_name, ptr); + int &rc = app_ref_counts[key]; + rc += 1; + if (rc == 1) { + logging::info("IREE new {}", key); + } else { + logging::info("IREE retain {} = {}", key, rc); + } +} + +void LogIREERelease(const char *type_name, void *ptr) { + slim_mutex_lock_guard g(log_mutex); + std::string key = fmt::format("{}({})", type_name, ptr); + int &rc = app_ref_counts[key]; + rc -= 1; + if (rc == 0) { + logging::info("IREE delete {}", key); + } else { + logging::info("IREE release {} = {}", key, rc); + } +} + +void LogIREESteal(const char *type_name, void *ptr) { + slim_mutex_lock_guard g(log_mutex); + std::string key = fmt::format("{}({})", type_name, ptr); + int &rc = app_ref_counts[key]; + rc += 1; + if (rc == 1) { + logging::info("IREE steal {}", key); + } else { + logging::info("IREE retain {} = {}", key, rc); + } +} + +void SHORTFIN_API LogLiveRefs() { + slim_mutex_lock_guard g(log_mutex); + bool logged_banner = false; + for (auto &it : app_ref_counts) { + if (it.second == 0) continue; + if (it.second < 0) { + logging::error("Shortfin IREE negative reference count: {} = {}", + it.first, it.second); + continue; + } + if (!logged_banner) { + logged_banner = true; + logging::warn("Shortfin visible live IREE refs remain:"); + } + logging::warn(" Live IREE ref {} = {}", it.first, it.second); + } +} + +#endif + +} // namespace detail + error::error(std::string message, iree_status_t failing_status) : message_(std::move(message)), failing_status_(failing_status) { message_.append(": "); @@ -19,7 +92,7 @@ void error::AppendStatus() const noexcept { status_appended_ = false; iree_allocator_t allocator = iree_allocator_system(); - char* status_buffer = nullptr; + char *status_buffer = nullptr; iree_host_size_t length = 0; if (iree_status_to_string(failing_status_, &allocator, &status_buffer, &length)) { diff --git a/libshortfin/src/shortfin/support/iree_helpers.h b/libshortfin/src/shortfin/support/iree_helpers.h index 8cbe368fd..c77ddbaa8 100644 --- a/libshortfin/src/shortfin/support/iree_helpers.h +++ b/libshortfin/src/shortfin/support/iree_helpers.h @@ -17,6 +17,10 @@ #include "iree/vm/api.h" #include "shortfin/support/api.h" +#if !defined(SHORTFIN_IREE_LOG_RC) +#define SHORTFIN_IREE_LOG_RC 0 +#endif + namespace shortfin { // -------------------------------------------------------------------------- // @@ -36,59 +40,142 @@ namespace iree { namespace detail { +#if SHORTFIN_IREE_LOG_RC +void SHORTFIN_API LogIREERetain(const char *type_name, void *ptr); +void SHORTFIN_API LogIREERelease(const char *type_name, void *ptr); +void SHORTFIN_API LogIREESteal(const char *type_name, void *ptr); +void SHORTFIN_API LogLiveRefs(); +#else +inline void LogIREERetain(const char *type_name, void *ptr) {} +inline void LogIREERelease(const char *type_name, void *ptr) {} +inline void LogIREESteal(const char *type_name, void *ptr) {} +inline void LogLiveRefs() {} +#endif + struct hal_buffer_ptr_helper { - static void retain(iree_hal_buffer_t *obj) { iree_hal_buffer_retain(obj); } - static void release(iree_hal_buffer_t *obj) { iree_hal_buffer_release(obj); } + static void steal(iree_hal_buffer_t *obj) { + LogIREESteal("iree_hal_buffer_t", obj); + } + static void retain(iree_hal_buffer_t *obj) { + LogIREERetain("iree_hal_buffer_t", obj); + iree_hal_buffer_retain(obj); + } + static void release(iree_hal_buffer_t *obj) { + LogIREERelease("iree_hal_buffer_t", obj); + iree_hal_buffer_release(obj); + } }; struct hal_command_buffer_helper { + static void steal(iree_hal_command_buffer_t *obj) { + LogIREESteal("iree_hal_command_buffer_t", obj); + } static void retain(iree_hal_command_buffer_t *obj) { + LogIREERetain("iree_hal_command_buffer_t", obj); iree_hal_command_buffer_retain(obj); } static void release(iree_hal_command_buffer_t *obj) { + LogIREERelease("iree_hal_command_buffer_t", obj); iree_hal_command_buffer_release(obj); } }; struct hal_device_ptr_helper { - static void retain(iree_hal_device_t *obj) { iree_hal_device_retain(obj); } - static void release(iree_hal_device_t *obj) { iree_hal_device_release(obj); } + static void steal(iree_hal_device_t *obj) { + LogIREESteal("iree_hal_device_t", obj); + } + static void retain(iree_hal_device_t *obj) { + LogIREERetain("iree_hal_device_t", obj); + iree_hal_device_retain(obj); + } + static void release(iree_hal_device_t *obj) { + LogIREERelease("iree_hal_device_t", obj); + iree_hal_device_release(obj); + } }; struct hal_driver_ptr_helper { - static void retain(iree_hal_driver_t *obj) { iree_hal_driver_retain(obj); } - static void release(iree_hal_driver_t *obj) { iree_hal_driver_release(obj); } + static void steal(iree_hal_driver_t *obj) { + LogIREESteal("iree_hal_driver_t", obj); + } + static void retain(iree_hal_driver_t *obj) { + LogIREERetain("iree_hal_driver_t", obj); + iree_hal_driver_retain(obj); + } + static void release(iree_hal_driver_t *obj) { + LogIREERelease("iree_hal_driver_t", obj); + iree_hal_driver_release(obj); + } }; struct hal_fence_ptr_helper { - static void retain(iree_hal_fence_t *obj) { iree_hal_fence_retain(obj); } - static void release(iree_hal_fence_t *obj) { iree_hal_fence_release(obj); } + static void steal(iree_hal_fence_t *obj) { + LogIREESteal("iree_hal_fence_t", obj); + } + static void retain(iree_hal_fence_t *obj) { + LogIREERetain("iree_hal_fence_t", obj); + iree_hal_fence_retain(obj); + } + static void release(iree_hal_fence_t *obj) { + LogIREERelease("iree_hal_fence_t", obj); + iree_hal_fence_release(obj); + } }; struct hal_semaphore_ptr_helper { + static void steal(iree_hal_semaphore_t *obj) { + LogIREESteal("iree_hal_semaphore_t", obj); + } static void retain(iree_hal_semaphore_t *obj) { + LogIREERetain("iree_hal_semaphore_t", obj); iree_hal_semaphore_retain(obj); } static void release(iree_hal_semaphore_t *obj) { + LogIREERelease("iree_hal_semaphore_t", obj); iree_hal_semaphore_release(obj); } }; struct vm_context_ptr_helper { - static void retain(iree_vm_context_t *obj) { iree_vm_context_retain(obj); } - static void release(iree_vm_context_t *obj) { iree_vm_context_release(obj); } + static void steal(iree_vm_context_t *obj) { + LogIREESteal("iree_vm_context_t", obj); + } + static void retain(iree_vm_context_t *obj) { + LogIREERetain("iree_vm_context_t", obj); + iree_vm_context_retain(obj); + } + static void release(iree_vm_context_t *obj) { + LogIREERelease("iree_vm_context_t", obj); + iree_vm_context_release(obj); + } }; struct vm_instance_ptr_helper { - static void retain(iree_vm_instance_t *obj) { iree_vm_instance_retain(obj); } + static void steal(iree_vm_instance_t *obj) { + LogIREESteal("iree_vm_instance_t", obj); + } + static void retain(iree_vm_instance_t *obj) { + LogIREERetain("iree_vm_instance_t", obj); + iree_vm_instance_retain(obj); + } static void release(iree_vm_instance_t *obj) { + LogIREERelease("iree_vm_instance_t", obj); iree_vm_instance_release(obj); } }; struct vm_module_ptr_helper { - static void retain(iree_vm_module_t *obj) { iree_vm_module_retain(obj); } - static void release(iree_vm_module_t *obj) { iree_vm_module_release(obj); } + static void steal(iree_vm_module_t *obj) { + LogIREESteal("iree_vm_module_t", obj); + } + static void retain(iree_vm_module_t *obj) { + LogIREERetain("iree_vm_module_t", obj); + iree_vm_module_retain(obj); + } + static void release(iree_vm_module_t *obj) { + LogIREERelease("iree_vm_module_t", obj); + iree_vm_module_release(obj); + } }; }; // namespace detail @@ -105,41 +192,60 @@ class object_ptr { } } object_ptr(object_ptr &&other) : ptr(other.ptr) { other.ptr = nullptr; } + object_ptr &operator=(const object_ptr &other) = delete; object_ptr &operator=(object_ptr &&other) { + reset(); ptr = other.ptr; other.ptr = nullptr; return *this; } - ~object_ptr() { - if (ptr) { - Helper::release(ptr); - } - } + ~object_ptr() { reset(); } // Constructs a new object_ptr by transferring ownership of a raw // pointer. - static object_ptr steal_reference(T *owned) { return object_ptr(owned); } + static object_ptr steal_reference(T *owned) { + Helper::steal(owned); + return object_ptr(owned); + } + // Constructs a new object_ptr by retaining a raw pointer. static object_ptr borrow_reference(T *owned) { Helper::retain(owned); return object_ptr(owned); } operator T *() const noexcept { return ptr; } + class Assignment { + public: + explicit Assignment(object_ptr *assign) : assign(assign) {} + ~Assignment() { + if (assign->ptr) { + Helper::steal(assign->ptr); + } + } + + constexpr operator T **() noexcept { + return reinterpret_cast(&assign->ptr); + } + + private: + object_ptr *assign = nullptr; + }; + // Releases any current reference held by this instance and returns a // pointer to the raw backing pointer. This is typically used for passing // to out parameters which are expected to store a new owned pointer directly. - T **for_output() { + constexpr Assignment for_output() noexcept { reset(); - return &ptr; + return Assignment(this); } operator bool() const { return ptr != nullptr; } T *get() const { return ptr; } - void reset(T *other = nullptr) { + void reset() { if (ptr) { Helper::release(ptr); } - ptr = other; + ptr = nullptr; } T *release() { T *ret = ptr; @@ -151,6 +257,8 @@ class object_ptr { // Assumes the reference count for owned_ptr. object_ptr(T *owned_ptr) : ptr(owned_ptr) {} T *ptr = nullptr; + + friend class Assignment; }; using hal_buffer_ptr = diff --git a/libshortfin/src/shortfin/support/iree_helpers_test.cc b/libshortfin/src/shortfin/support/iree_helpers_test.cc index a13b81b72..bf059ee98 100644 --- a/libshortfin/src/shortfin/support/iree_helpers_test.cc +++ b/libshortfin/src/shortfin/support/iree_helpers_test.cc @@ -29,6 +29,7 @@ struct iree_dummy_t { }; struct dummy_ptr_helper { + static void steal(iree_dummy_t *obj) {} static void retain(iree_dummy_t *obj) { obj->retain_count++; } static void release(iree_dummy_t *obj) { obj->release_count++; } }; diff --git a/libshortfin/src/shortfin/support/logging.h b/libshortfin/src/shortfin/support/logging.h index 55bd36347..337ebacae 100644 --- a/libshortfin/src/shortfin/support/logging.h +++ b/libshortfin/src/shortfin/support/logging.h @@ -9,6 +9,10 @@ #include "spdlog/spdlog.h" +#if !defined(SHORTFIN_LOG_LIFETIMES) +#define SHORTFIN_LOG_LIFETIMES 0 +#endif + namespace shortfin::logging { // TODO: Re-export doesn't really work like this. Need to define API @@ -18,6 +22,22 @@ using spdlog::error; using spdlog::info; using spdlog::warn; +#if SHORTFIN_LOG_LIFETIMES +template +inline void construct(const char* type_name, T* inst) { + info("new {}({})", type_name, static_cast(inst)); +} +template +inline void destruct(const char* type_name, T* inst) { + info("delete {}({})", type_name, static_cast(inst)); +} +#else +template +inline void construct(const char *type_name, T *) {} +template +inline void destruct(const char *type_name, T *) {} +#endif + } // namespace shortfin::logging #endif // SHORTFIN_SUPPORT_LOGGING_H From ba58e4d8f0907785a99c6a36510339889bee7396 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 26 Aug 2024 21:29:29 -0700 Subject: [PATCH 20/20] [shortfin] Add documentation on ownership and lifetime of the system hierarchy. --- libshortfin/src/shortfin/local/system.h | 37 +++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/libshortfin/src/shortfin/local/system.h b/libshortfin/src/shortfin/local/system.h index 3a5bbfd86..cb5c70808 100644 --- a/libshortfin/src/shortfin/local/system.h +++ b/libshortfin/src/shortfin/local/system.h @@ -42,9 +42,40 @@ class SystemBuilder; // on some form of factory that constructs one to suit both the system being // executed on and any preferences on which resources should be accessible. // -// As the root of the hierarchy and the owner of numerous ancillary resources, -// we declare that System is always managed via a shared_ptr, as this -// simplifies many aspects of system management. +// Ownership +// --------- +// There are three levels of ownership, all rooted on the System: +// 1. System: The System class, all drivers, devices, workers, and executors. +// There will only ever be one (or a small number if doing something multi +// tenant), and all owning references to the System are via +// `std::shared_ptr`. Every object in the system must either be +// a managed child of the system or own a system reference. +// 2. Scope: Binds any number of devices to a coherent schedule, rooted on +// a Worker. Scopes are independent of the system and there are generally +// as many as needed logical concurrency in the application. Each scope +// holds a system reference by way of a `std::shared_ptr`. These +// are still heavy-weight objects mostly created at initialization time +// and are therefore managed held as a `std::shared_ptr` by anything +// that depends on them. +// 3. TimelineResource: Any resource in the system (i.e. buffer, +// synchronization, object, etc) will hold a unique TimelineResource. These +// are light-weight objects managed via intrusive reference counting by +// their contained `TimelineResource::Ref` class. Each `TimelineResource` +// maintains a `std::shared_ptr` back reference to its owning +// scope. +// +// Leaf objects can have any lifetime that they wish, so long as they maintain +// an appropriate ownership reference into the System hierarchy above. This +// includes any application managed objects like arrays, storage, processes, +// messages, queues, etc. +// +// Lifetime debug logging can be enabled via compiler defines: +// SHORTFIN_LOG_LIFETIMES=1 : Enables constructor/destructor and this pointer +// logging for the primary objects in the system hierarchy. +// SHORTFIN_IREE_LOG_RC=1 : Enables the application view of IREE object +// reference counting, showing steal/retain/release and the number of +// references the application holds for each object. Also will log any +// outstanding references when the System is deallocated. class SHORTFIN_API System : public std::enable_shared_from_this { public: System(iree_allocator_t host_allocator);