From 08c69aac1ce7d044b2ed24f94ee03a78d1b89cb1 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 3 Sep 2024 14:48:28 -0700 Subject: [PATCH] [libshortfin] Implement invocation. (#159) --- .../workflows/ci_linux_x64-libshortfin.yml | 2 +- .../ci_linux_x64_asan-libshortfin.yml | 1 + libshortfin/bindings/python/array_binding.cc | 126 +++-- libshortfin/bindings/python/lib_ext.cc | 159 +++++- .../bindings/python/shortfin/__init__.py | 6 + .../mobilenet_server/inference_system.py | 44 +- libshortfin/requirements-tests.txt | 2 +- libshortfin/src/shortfin/array/CMakeLists.txt | 1 + libshortfin/src/shortfin/array/array.cc | 61 ++- libshortfin/src/shortfin/array/array.h | 51 +- libshortfin/src/shortfin/array/dims.h | 5 +- libshortfin/src/shortfin/array/dtype.cc | 19 + libshortfin/src/shortfin/array/dtype.h | 65 +-- libshortfin/src/shortfin/array/dtypes.inl | 34 ++ libshortfin/src/shortfin/array/storage.cc | 100 +++- libshortfin/src/shortfin/array/storage.h | 62 ++- libshortfin/src/shortfin/local/async.cc | 2 +- libshortfin/src/shortfin/local/async.h | 10 + libshortfin/src/shortfin/local/device.h | 1 + libshortfin/src/shortfin/local/messaging.h | 2 +- libshortfin/src/shortfin/local/program.cc | 484 +++++++++++++++++- libshortfin/src/shortfin/local/program.h | 243 ++++++++- .../src/shortfin/local/program_interfaces.h | 85 +++ libshortfin/src/shortfin/local/scheduler.cc | 62 ++- libshortfin/src/shortfin/local/scheduler.h | 14 +- libshortfin/src/shortfin/local/scope.cc | 43 -- libshortfin/src/shortfin/local/scope.h | 8 - libshortfin/src/shortfin/local/worker.h | 1 + .../src/shortfin/support/iree_concurrency.h | 16 +- .../src/shortfin/support/iree_helpers.h | 200 ++------ libshortfin/src/shortfin/support/logging.h | 8 + 31 files changed, 1505 insertions(+), 412 deletions(-) create mode 100644 libshortfin/src/shortfin/array/dtypes.inl create mode 100644 libshortfin/src/shortfin/local/program_interfaces.h diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml index babcf0245..b9fcb8777 100644 --- a/.github/workflows/ci_linux_x64-libshortfin.yml +++ b/.github/workflows/ci_linux_x64-libshortfin.yml @@ -83,8 +83,8 @@ jobs: - name: Install Python packages # TODO: Switch to `pip install -r requirements.txt -e libshortfin/`. run: | - pip install nanobind pip install -r ${{ env.LIBSHORTFIN_DIR }}/requirements-tests.txt + pip freeze - name: Build libshortfin (full) run: | diff --git a/.github/workflows/ci_linux_x64_asan-libshortfin.yml b/.github/workflows/ci_linux_x64_asan-libshortfin.yml index 14aa26bda..f0ecc5452 100644 --- a/.github/workflows/ci_linux_x64_asan-libshortfin.yml +++ b/.github/workflows/ci_linux_x64_asan-libshortfin.yml @@ -124,6 +124,7 @@ jobs: run: | eval "$(pyenv init -)" pip install -r ${{ env.LIBSHORTFIN_DIR }}/requirements-tests.txt + pip freeze - name: Save Python dependencies cache if: steps.cache-python-deps-restore.outputs.cache-hit != 'true' diff --git a/libshortfin/bindings/python/array_binding.cc b/libshortfin/bindings/python/array_binding.cc index 9858c2350..f294c80a4 100644 --- a/libshortfin/bindings/python/array_binding.cc +++ b/libshortfin/bindings/python/array_binding.cc @@ -13,6 +13,23 @@ using namespace shortfin::array; namespace shortfin::python { namespace { +static const char DOCSTRING_ARRAY_COPY_FROM[] = + R"(Copy contents from a source array to this array. + +Equivalent to `dest_array.storage.copy_from(source_array.storage)`. 
+)"; + +static const char DOCSTRING_ARRAY_COPY_TO[] = + R"(Copy contents this array to a destination array. + +Equivalent to `dest_array.storage.copy_from(source_array.storage)`. +)"; + +static const char DOCSTRING_ARRAY_FILL[] = R"(Fill an array with a value. + +Equivalent to `array.storage.fill(pattern)`. +)"; + static const char DOCSTRING_STORAGE_DATA[] = R"(Access raw binary contents. Accessing `foo = storage.data` is equivalent to `storage.data.map(read=True)`. @@ -28,6 +45,23 @@ As with `map`, this will only work on buffers that are host visible, which includes all host buffers and device buffers created with the necessary access. )"; +static const char DOCSTRING_STORAGE_COPY_FROM[] = + R"(Copy contents from a source storage to this array. + +This operation executes asynchronously and the effect will only be visible +once the execution scope has been synced to the point of mutation. +)"; + +static const char DOCSTRING_STORAGE_FILL[] = R"(Fill a storage with a value. + +Takes as argument any value that can be interpreted as a buffer with the Python +buffer protocol of size 1, 2, or 4 bytes. The storage will be filled uniformly +with the pattern. + +This operation executes asynchronously and the effect will only be visible +once the execution scope has been synced to the point of mutation. +)"; + static const char DOCSTRING_STORAGE_MAP[] = R"(Create a mapping of the buffer contents in host memory. @@ -72,58 +106,47 @@ void BindArray(py::module_ &m) { .def(py::self == py::self) .def("__repr__", &DType::name); - m.attr("opaque8") = DType::opaque8(); - m.attr("opaque16") = DType::opaque16(); - m.attr("opaque32") = DType::opaque32(); - m.attr("opaque64") = DType::opaque64(); - m.attr("bool8") = DType::bool8(); - m.attr("int4") = DType::int4(); - m.attr("sint4") = DType::sint4(); - m.attr("uint4") = DType::uint4(); - m.attr("int8") = DType::int8(); - m.attr("sint8") = DType::sint8(); - m.attr("uint8") = DType::uint8(); - m.attr("int16") = DType::int16(); - m.attr("sint16") = DType::sint16(); - m.attr("uint16") = DType::uint16(); - m.attr("int32") = DType::int32(); - m.attr("sint32") = DType::sint32(); - m.attr("uint32") = DType::uint32(); - m.attr("int64") = DType::int64(); - m.attr("sint64") = DType::sint64(); - m.attr("uint64") = DType::uint64(); - m.attr("float16") = DType::float16(); - m.attr("float32") = DType::float32(); - m.attr("float64") = DType::float64(); - m.attr("bfloat16") = DType::bfloat16(); - m.attr("complex64") = DType::complex64(); - m.attr("complex128") = DType::complex128(); +#define SHORTFIN_DTYPE_HANDLE(et, ident) m.attr(#ident) = DType::ident(); +#include "shortfin/array/dtypes.inl" +#undef SHORTFIN_DTYPE_HANDLE // storage py::class_(m, "storage") + .def("__sfinv_marshal__", + [](device_array *self, py::capsule inv_capsule, int barrier) { + auto *inv = + static_cast(inv_capsule.data()); + static_cast(self) + ->AddAsInvocationArgument( + inv, static_cast(barrier)); + }) .def_static( "allocate_host", [](local::ScopedDevice &device, iree_device_size_t allocation_size) { - return storage::AllocateHost(device, allocation_size); + return storage::allocate_host(device, allocation_size); }, py::arg("device"), py::arg("allocation_size"), py::keep_alive<0, 1>()) .def_static( "allocate_device", [](local::ScopedDevice &device, iree_device_size_t allocation_size) { - return storage::AllocateDevice(device, allocation_size); + return storage::allocate_device(device, allocation_size); }, py::arg("device"), py::arg("allocation_size"), py::keep_alive<0, 1>()) - .def("fill", - 
[](storage &self, py::handle buffer) { - Py_buffer py_view; - int flags = PyBUF_FORMAT | PyBUF_ND; // C-Contiguous ND. - if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) { - throw py::python_error(); - } - PyBufferReleaser py_view_releaser(py_view); - self.Fill(py_view.buf, py_view.len); - }) - .def("copy_from", [](storage &self, storage &src) { self.CopyFrom(src); }) + .def( + "fill", + [](storage &self, py::handle buffer) { + Py_buffer py_view; + int flags = PyBUF_FORMAT | PyBUF_ND; // C-Contiguous ND. + if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) { + throw py::python_error(); + } + PyBufferReleaser py_view_releaser(py_view); + self.fill(py_view.buf, py_view.len); + }, + py::arg("pattern"), DOCSTRING_STORAGE_FILL) + .def( + "copy_from", [](storage &self, storage &src) { self.copy_from(src); }, + py::arg("source_storage"), DOCSTRING_STORAGE_COPY_FROM) .def( "map", [](storage &self, bool read, bool write, bool discard) { @@ -137,7 +160,7 @@ void BindArray(py::module_ &m) { } mapping *cpp_mapping = nullptr; py::object py_mapping = CreateMappingObject(&cpp_mapping); - self.MapExplicit( + self.map_explicit( *cpp_mapping, static_cast(access)); return py_mapping; @@ -154,12 +177,12 @@ void BindArray(py::module_ &m) { [](storage &self) { mapping *cpp_mapping = nullptr; py::object py_mapping = CreateMappingObject(&cpp_mapping); - *cpp_mapping = self.MapRead(); + *cpp_mapping = self.map_read(); return py_mapping; }, [](storage &self, py::handle buffer_obj) { PyBufferRequest src_info(buffer_obj, PyBUF_SIMPLE); - auto dest_data = self.MapWriteDiscard(); + auto dest_data = self.map_write_discard(); if (src_info.view().len > dest_data.size()) { throw std::invalid_argument( fmt::format("Cannot write {} bytes into buffer of {} bytes", @@ -219,6 +242,14 @@ void BindArray(py::module_ &m) { py_type, /*keep_alive=*/device.scope(), device_array::for_device(device, shape, dtype)); }) + .def("__sfinv_marshal__", + [](device_array *self, py::capsule inv_capsule, int barrier) { + auto *inv = + static_cast(inv_capsule.data()); + static_cast(self) + ->AddAsInvocationArgument( + inv, static_cast(barrier)); + }) .def_static("for_device", [](local::ScopedDevice &device, std::span shape, DType dtype) { @@ -243,6 +274,17 @@ void BindArray(py::module_ &m) { py::rv_policy::reference_internal) .def_prop_ro("storage", &device_array::storage, py::rv_policy::reference_internal) + + .def( + "fill", + [](py::handle_t self, py::handle buffer) { + self.attr("storage").attr("fill")(buffer); + }, + py::arg("pattern"), DOCSTRING_ARRAY_FILL) + .def("copy_from", &device_array::copy_from, py::arg("source_array"), + DOCSTRING_ARRAY_COPY_FROM) + .def("copy_to", &device_array::copy_to, py::arg("dest_array"), + DOCSTRING_ARRAY_COPY_TO) .def("__repr__", &device_array::to_s); } diff --git a/libshortfin/bindings/python/lib_ext.cc b/libshortfin/bindings/python/lib_ext.cc index 6072caa04..e15467634 100644 --- a/libshortfin/bindings/python/lib_ext.cc +++ b/libshortfin/bindings/python/lib_ext.cc @@ -7,6 +7,8 @@ #include "./lib_ext.h" #include "./utils.h" +#include "shortfin/array/array.h" +#include "shortfin/array/storage.h" #include "shortfin/local/async.h" #include "shortfin/local/messaging.h" #include "shortfin/local/process.h" @@ -24,6 +26,13 @@ namespace shortfin::python { namespace { +static const char DOCSTRING_PROGRAM_FUNCTION_INVOCATION[] = + R"(Creates an invocation object targeting the function. 
+ +This is a low-level interface for performing an invocation, and it should be +used when precise, non-default control is needed. +)"; + class Refs { public: py::object asyncio_create_task = @@ -182,6 +191,59 @@ class PyProcess : public local::detail::BaseProcess { std::shared_ptr refs_; }; +void PyAddProgramInvocationArg(py::capsule &inv_capsule, py::handle arg) { + // See if the object implements our marshaling protocol. If it does, then + // We invoke the marshaling method with the Invocation wrapped as a capsule + // and the ProgramResourceBarrier. + py::object marshaler = py::getattr(arg, "__sfinv_marshal__", py::none()); + if (!marshaler.is_none()) { + marshaler(inv_capsule, + static_cast(local::ProgramResourceBarrier::DEFAULT)); + return; + } + + throw std::invalid_argument( + fmt::format("Unsupported argument type {} in call to ProgramFunction", + py::cast(py::repr(arg.type())))); +} + +local::ProgramInvocation::Future PyFunctionCall(local::ProgramFunction &self, + py::args args) { + auto inv = self.CreateInvocation(); + py::capsule inv_capsule(inv.get()); + for (py::handle arg : args) { + PyAddProgramInvocationArg(inv_capsule, arg); + } + return local::ProgramInvocation::Invoke(std::move(inv)); +} + +py::object PyRehydrateRef(local::ProgramInvocation *inv, + iree::vm_opaque_ref ref) { + auto type = ref.get()->type; + // Note that these accessors are dangerous as they assert/abort if + // process-wide registration is not done properly. We assume here that + // since we got a ref out that the basics are set up soundly, but if actually + // doing this on user/dynamic types, we would want to be more defensive. + // TODO: Don't just do a linear scan if we have more than a couple. + // TODO: Find a reliable way to statically cache the type id. + if (local::ProgramInvocationMarshalableFactory::invocation_marshalable_type< + array::device_array>() == type) { + // device_array + return py::cast(local::ProgramInvocationMarshalableFactory:: + CreateFromInvocationResultRef( + inv, std::move(ref))); + } else if (local::ProgramInvocationMarshalableFactory:: + invocation_marshalable_type() == type) { + // storage + return py::cast( + local::ProgramInvocationMarshalableFactory:: + CreateFromInvocationResultRef(inv, std::move(ref))); + } + throw std::invalid_argument( + fmt::format("Cannot marshal ref type {} to Python", + to_string_view(iree_vm_ref_type_name(type)))); +} + py::object RunInForeground(std::shared_ptr refs, local::System &self, py::object coro) { bool is_main_thread = @@ -237,7 +299,7 @@ py::object RunInForeground(std::shared_ptr refs, local::System &self, } // namespace NB_MODULE(lib, m) { - m.def("initialize", shortfin::GlobalInitialize); + py::class_(m, "_OpaqueVmRef"); auto local_m = m.def_submodule("local"); BindLocal(local_m); BindHostSystem(local_m); @@ -379,11 +441,79 @@ void BindLocal(py::module_ &m) { .def("__add__", &local::DeviceAffinity::AddDevice) .def("__repr__", &local::DeviceAffinity::to_s); - py::class_(m, "Program"); + py::class_(m, "Program") + .def(py::new_([](std::span modules, + local::Scope &scope, bool trace_execution) { + local::Program::Options options; + options.trace_execution = trace_execution; + return local::Program::Load(scope.shared_from_this(), modules, + std::move(options)); + }), + py::arg("modules"), py::arg("scope"), py::kw_only(), + py::arg("trace_execution") = false) + .def_prop_ro("exports", &local::Program::exports) + .def("lookup_function", &local::Program::LookupRequiredFunction) + .def("__getitem__", 
&local::Program::LookupRequiredFunction); + py::class_(m, "ProgramFunction") + .def_prop_ro("name", &local::ProgramFunction::name) + .def_prop_ro("calling_convention", + &local::ProgramFunction::calling_convention) + .def("invocation", &local::ProgramFunction::CreateInvocation, + DOCSTRING_PROGRAM_FUNCTION_INVOCATION) + .def("__call__", PyFunctionCall, py::arg("args")) + .def("__repr__", &local::ProgramFunction::to_s); py::class_(m, "ProgramModule") + .def_prop_ro("exports", &local::ProgramModule::exports) .def("__repr__", &local::ProgramModule::to_s) .def_static("load", &local::ProgramModule::Load, py::arg("system"), py::arg("path"), py::arg("mmap") = true); + py::class_(m, "ProgramInvocation") + .def("invoke", + [](local::ProgramInvocation::Ptr &self) { + if (!self) throw std::invalid_argument("Deallocated invocation"); + return local::ProgramInvocation::Invoke(std::move(self)); + }) + .def("add_arg", + [](local::ProgramInvocation::Ptr &self, py::handle arg) { + if (!self) throw std::invalid_argument("Deallocated invocation"); + py::capsule inv_capsule(self.get()); + PyAddProgramInvocationArg(inv_capsule, arg); + }) + .def("__iter__", + [](local::ProgramInvocation::Ptr &self) { + if (!self) throw std::invalid_argument("Deallocated invocation"); + size_t size = self->results_size(); + py::object tp = py::steal(PyTuple_New(size)); + for (size_t i = 0; i < size; ++i) { + iree::vm_opaque_ref ref = self->result_ref(i); + if (!ref) { + throw new std::logic_error( + "Program returned unsupported Python type"); + } + py::object item = PyRehydrateRef(self.get(), std::move(ref)); + PyTuple_SET_ITEM(tp.ptr(), i, item.release().ptr()); + } + return tp.attr("__iter__")(); + }) + .def( + "__len__", + [](local::ProgramInvocation::Ptr &self) { + if (!self) throw std::invalid_argument("Deallocated invocation"); + return self->results_size(); + }, + "The number of results in this invocation") + .def( + "__getitem__", + [](local::ProgramInvocation::Ptr &self, iree_host_size_t i) { + if (!self) throw std::invalid_argument("Deallocated invocation"); + iree::vm_opaque_ref ref = self->result_ref(i); + if (!ref) { + throw new std::logic_error( + "Program returned unsupported Python type"); + } + return PyRehydrateRef(self.get(), std::move(ref)); + }, + "Gets the i'th result"); struct DevicesSet { DevicesSet(local::Scope &scope) : scope(scope) {} @@ -414,16 +544,7 @@ void BindLocal(py::module_ &m) { [](local::Scope &self, py::args args) { return CastDeviceAffinity(self, args); }, - py::rv_policy::reference_internal) - .def( - "load_unbound_program", - [](local::Scope &scope, std::span modules, - bool trace_execution) { - local::Program::Options options; - options.trace_execution = trace_execution; - return scope.LoadUnboundProgram(modules, std::move(options)); - }, - py::arg("modules"), py::arg("trace_execution") = false); + py::rv_policy::reference_internal); py::class_(m, "ScopedDevice") .def_prop_ro("scope", &local::ScopedDevice::scope, @@ -696,6 +817,20 @@ void BindLocal(py::module_ &m) { return iter_ret; }); py::class_(m, "VoidFuture"); + py::class_( + m, "ProgramInvocationFuture") + .def("result", [](local::ProgramInvocation::Future &self) { + local::ProgramInvocation::Ptr &result = self.result(); + if (!result) return py::none(); + // Sharp edge: ProgramInvocationFutures are read-once since we move the + // ProgramInvocation::Ptr out of the future here and transfer ownership + // to a Python object. 
There isn't a better way to do this without + // increasing overhead on this hot path or doing something more + // expensive in the C++ API: essentially, ProgramInvocations flow + // through the system precisely one way. As a low level facility, this + // is deemed acceptable. + return py::cast(std::move(result)); + }); py::class_(m, "MessageFuture") .def("result", [](local::MessageFuture &self) { // Get a raw backing msg (without an increased refcount). When cast diff --git a/libshortfin/bindings/python/shortfin/__init__.py b/libshortfin/bindings/python/shortfin/__init__.py index 5426f2ad2..09b4301b0 100644 --- a/libshortfin/bindings/python/shortfin/__init__.py +++ b/libshortfin/bindings/python/shortfin/__init__.py @@ -14,6 +14,9 @@ Node = _sfl.local.Node Process = _sfl.local.Process Program = _sfl.local.Program +ProgramFunction = _sfl.local.ProgramFunction +ProgramInvocation = _sfl.local.ProgramInvocation +ProgramInvocationFuture = _sfl.local.ProgramInvocationFuture ProgramModule = _sfl.local.ProgramModule Queue = _sfl.local.Queue QueueReader = _sfl.local.QueueReader @@ -37,6 +40,9 @@ "Message", "Node", "Program", + "ProgramFunction", + "ProgramInvocation", + "ProgramInvocationFuture", "ProgramModule", "Queue", "QueueReader", diff --git a/libshortfin/examples/python/mobilenet_server/inference_system.py b/libshortfin/examples/python/mobilenet_server/inference_system.py index 8ae7773db..9967e0aa1 100644 --- a/libshortfin/examples/python/mobilenet_server/inference_system.py +++ b/libshortfin/examples/python/mobilenet_server/inference_system.py @@ -24,7 +24,7 @@ def __init__(self, raw_image_data): class InferenceProcess(sf.Process): def __init__(self, program, request_queue, **kwargs): super().__init__(**kwargs) - self.program = program + self.main_function = program["module.torch-jit-export"] self.request_reader = request_queue.reader() self.device = self.scope.device(0) self.device_input = sfnp.device_array( @@ -41,14 +41,46 @@ async def run(self): # support for. Generally, APIs on storage should be mirrored onto # the array. self.host_staging.storage.data = request.raw_image_data - print(self.host_staging) - self.device_input.storage.copy_from(self.host_staging.storage) - print(self.device_input) + print("host_staging =", self.host_staging) + self.device_input.copy_from(self.host_staging) + + # Simple call. Note that the await here is merely awaiting the + # result being *available* (i.e. that the VM coroutine has + # completed) but does not indicate that the result is ready. + (result1,) = await self.main_function(self.device_input) + (result2,) = await self.main_function(self.device_input) + + # TODO: Implement await on individual results. The accounting is + # there but currently we can only await on the device itself. + await self.device + print("Result 1:", result1) + print("Result 2:", result2) + + # Explicit invocation object. + # inv = self.main_function.invocation(scope=self.scope) + # inv.add_arg(self.device_input) + # results = await inv.invoke() + # print("results:", results) + + # Multiple invocations in parallel. 
+ # all_results = await asyncio.gather( + # self.main_function(self.device_input, scope=self.scope), + # self.main_function(self.device_input, scope=self.scope), + # self.main_function(self.device_input, scope=self.scope), + # ) + # print("All results:", all_results) + + # output = await self.scope.invoke(self.main_function, self.device_input) + # print("OUTPUT:", output) + # read_back = self.device_input.for_transfer() + # read_back.copy_from(self.device_input) + # await self.device + # print("read back =", read_back) class Main: def __init__(self, lsys: sf.System, home_dir: Path): - self.processes_per_worker = 1 + self.processes_per_worker = 2 self.lsys = lsys self.home_dir = home_dir self.request_queue = lsys.create_queue("request") @@ -60,8 +92,8 @@ async def start_scope(self, scope): # Note that currently, program load is synchronous. But we do it # in a task so we can await it in the future and let program loads # overlap. - program = scope.load_unbound_program([self.program_module]) for _ in range(self.processes_per_worker): + program = sf.Program([self.program_module], scope=scope) self.processes.append( InferenceProcess(program, self.request_queue, scope=scope).launch() ) diff --git a/libshortfin/requirements-tests.txt b/libshortfin/requirements-tests.txt index 50bdd9831..a5392667d 100644 --- a/libshortfin/requirements-tests.txt +++ b/libshortfin/requirements-tests.txt @@ -1,4 +1,4 @@ -nanobind==2.0.0 +nanobind==2.1.0 pytest requests fastapi diff --git a/libshortfin/src/shortfin/array/CMakeLists.txt b/libshortfin/src/shortfin/array/CMakeLists.txt index 0e9360363..22cf3fd68 100644 --- a/libshortfin/src/shortfin/array/CMakeLists.txt +++ b/libshortfin/src/shortfin/array/CMakeLists.txt @@ -12,6 +12,7 @@ shortfin_cc_component( api.h dims.h dtype.h + dtypes.inl storage.h SRCS array.cc diff --git a/libshortfin/src/shortfin/array/array.cc b/libshortfin/src/shortfin/array/array.cc index 1d6d7cc5a..224d1177a 100644 --- a/libshortfin/src/shortfin/array/array.cc +++ b/libshortfin/src/shortfin/array/array.cc @@ -14,25 +14,25 @@ namespace shortfin::array { -template class InlinedDims; +template class InlinedDims; // -------------------------------------------------------------------------- // // device_array // -------------------------------------------------------------------------- // -const mapping device_array::data() const { return storage_.MapRead(); } +const mapping device_array::data() const { return storage_.map_read(); } -mapping device_array::data() { return storage_.MapRead(); } +mapping device_array::data() { return storage_.map_read(); } -mapping device_array::data_rw() { return storage_.MapReadWrite(); } +mapping device_array::data_rw() { return storage_.map_read_write(); } -mapping device_array::data_w() { return storage_.MapWriteDiscard(); } +mapping device_array::data_w() { return storage_.map_write_discard(); } std::optional device_array::map_memory_for_xtensor() { if (storage_.is_mappable_for_read_write()) { - return storage_.MapReadWrite(); + return storage_.map_read_write(); } else if (storage_.is_mappable_for_read()) { - return storage_.MapRead(); + return storage_.map_read(); } return {}; } @@ -52,10 +52,49 @@ std::string device_array::to_s() const { } } - return fmt::format("device_array([{}], dtype='{}', device={}({})) ={}{}", - fmt::join(shape(), ", "), dtype().name(), - storage_.device().to_s(), storage_.formatted_memory_type(), - contents_prefix, contents); + return fmt::format( + "device_array([{}], dtype='{}', device={}(type={}, usage={}, access={})) " + "={}{}", 
+ fmt::join(shape(), ", "), dtype().name(), storage_.device().to_s(), + storage_.formatted_memory_type(), storage_.formatted_buffer_usage(), + storage_.formatted_memory_access(), contents_prefix, contents); +} + +void device_array::AddAsInvocationArgument( + local::ProgramInvocation *inv, local::ProgramResourceBarrier barrier) { + auto dims_span = shape(); + iree_hal_buffer_view_t *buffer_view; + SHORTFIN_THROW_IF_ERROR(iree_hal_buffer_view_create( + storage_, dims_span.size(), dims_span.data(), dtype(), + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, storage_.host_allocator(), + &buffer_view)); + + iree::vm_opaque_ref ref; + *(&ref) = iree_hal_buffer_view_move_ref(buffer_view); + inv->AddArg(std::move(ref)); + + storage().AddInvocationArgBarrier(inv, barrier); +} + +iree_vm_ref_type_t device_array::invocation_marshalable_type() { + return iree_hal_buffer_view_type(); +} + +device_array device_array::CreateFromInvocationResultRef( + local::ProgramInvocation *inv, iree::vm_opaque_ref ref) { + // We don't retain the buffer view in the device array, so just deref it + // vs stealing the ref. + iree_hal_buffer_view_t *bv = iree_hal_buffer_view_deref(*ref.get()); + iree::hal_buffer_ptr buffer = + iree::hal_buffer_ptr::borrow_reference(iree_hal_buffer_view_buffer(bv)); + + auto imported_storage = + storage::ImportInvocationResultStorage(inv, std::move(buffer)); + std::span shape(iree_hal_buffer_view_shape_dims(bv), + iree_hal_buffer_view_shape_rank(bv)); + return device_array( + std::move(imported_storage), shape, + DType::import_element_type(iree_hal_buffer_view_element_type(bv))); } } // namespace shortfin::array diff --git a/libshortfin/src/shortfin/array/array.h b/libshortfin/src/shortfin/array/array.h index c3ab6e302..30b9d5e14 100644 --- a/libshortfin/src/shortfin/array/array.h +++ b/libshortfin/src/shortfin/array/array.h @@ -16,6 +16,7 @@ #include "shortfin/array/dtype.h" #include "shortfin/array/storage.h" #include "shortfin/array/xtensor_bridge.h" +#include "shortfin/local/program_interfaces.h" #include "shortfin/support/api.h" namespace shortfin::array { @@ -23,7 +24,8 @@ namespace shortfin::array { // Either a host or device nd-array view. class SHORTFIN_API base_array { public: - base_array(std::span shape, DType dtype) : dtype_(dtype) { + base_array(std::span shape, DType dtype) + : dtype_(dtype) { set_shape(shape); } // Need to explicitly define copy/move constructors even though this is @@ -38,9 +40,9 @@ class SHORTFIN_API base_array { DType dtype() const { return dtype_; } // Access shape. - void set_shape(std::span shape) { shape_.set(shape); } - std::span shape() const { return shape_.span(); } - std::span mutable_shape() { return shape_.span(); } + void set_shape(std::span shape) { shape_.set(shape); } + std::span shape() const { return shape_.span(); } + std::span mutable_shape() { return shape_.span(); } // Sometimes we need to access the raw shape container (i.e. for adapters, // etc). @@ -54,9 +56,10 @@ class SHORTFIN_API base_array { class SHORTFIN_API device_array : public base_array, - public poly_xt_mixin { + public poly_xt_mixin, + public local::ProgramInvocationMarshalable { public: - device_array(class storage storage, std::span shape, + device_array(class storage storage, std::span shape, DType dtype) : base_array(shape, dtype), storage_(std::move(storage)) {} @@ -65,18 +68,20 @@ class SHORTFIN_API device_array // Allocate an array on the device. 
static device_array for_device(local::ScopedDevice &device, - std::span shape, DType dtype) { + std::span shape, + DType dtype) { return device_array( - storage::AllocateDevice(device, dtype.compute_dense_nd_size(shape)), + storage::allocate_device(device, dtype.compute_dense_nd_size(shape)), shape, dtype); } // Allocates a host array that is registered by the device. This can include // arrays that are visible from different combinations of host/device. static device_array for_host(local::ScopedDevice &device, - std::span shape, DType dtype) { + std::span shape, + DType dtype) { return device_array( - storage::AllocateHost(device, dtype.compute_dense_nd_size(shape)), + storage::allocate_host(device, dtype.compute_dense_nd_size(shape)), shape, dtype); } @@ -85,6 +90,23 @@ class SHORTFIN_API device_array return for_host(storage().device(), shape(), dtype()); } + // Enqueues a fill of the storage with an arbitrary pattern of the given + // size. The pattern size must be 1, 2, or 4. Equivalent to calling the same + // on the backing storage. + void fill(const void *pattern, iree_host_size_t pattern_length) { + storage_.fill(pattern, pattern_length); + } + + // Performs either a d2h, h2d or d2d transfer from a source storage to this + // storage. Equivalent to calling the same on the backing storage. + void copy_from(device_array &source_array) { + storage_.copy_from(source_array.storage_); + } + // Inverse of copy_from. + void copy_to(device_array &dest_array) { + dest_array.storage_.copy_from(storage_); + } + // Untyped access to the backing data. The array must be mappable. Specific // access modes: // * data(): Read-only access to the data. @@ -123,6 +145,15 @@ class SHORTFIN_API device_array protected: class storage storage_; + + private: + // ProgramInvocationMarshalable implementation. 
+ void AddAsInvocationArgument(local::ProgramInvocation *inv, + local::ProgramResourceBarrier barrier) override; + static device_array CreateFromInvocationResultRef( + local::ProgramInvocation *inv, iree::vm_opaque_ref ref); + static iree_vm_ref_type_t invocation_marshalable_type(); + friend class shortfin::local::ProgramInvocationMarshalableFactory; }; } // namespace shortfin::array diff --git a/libshortfin/src/shortfin/array/dims.h b/libshortfin/src/shortfin/array/dims.h index 529aebc42..2988fb7cb 100644 --- a/libshortfin/src/shortfin/array/dims.h +++ b/libshortfin/src/shortfin/array/dims.h @@ -11,6 +11,7 @@ #include #include +#include "iree/hal/buffer_view.h" #include "shortfin/support/api.h" namespace shortfin::array { @@ -248,8 +249,8 @@ class SHORTFIN_API InlinedDims { _D dims_; }; -extern template class InlinedDims; -using Dims = InlinedDims; +extern template class InlinedDims; +using Dims = InlinedDims; } // namespace shortfin::array diff --git a/libshortfin/src/shortfin/array/dtype.cc b/libshortfin/src/shortfin/array/dtype.cc index e6b92a7e9..19cc97be9 100644 --- a/libshortfin/src/shortfin/array/dtype.cc +++ b/libshortfin/src/shortfin/array/dtype.cc @@ -6,6 +6,8 @@ #include "shortfin/array/dtype.h" +#include + #include "fmt/core.h" namespace shortfin::array { @@ -22,4 +24,21 @@ iree_device_size_t DType::compute_dense_nd_size(std::span dims) { return accum; } +DType DType::import_element_type(iree_hal_element_type_t et) { + static std::unordered_map static_canonical = + ([]() { + std::unordered_map c; + auto add = [&](DType dt) { c.emplace(std::make_pair(dt.et_, dt)); }; +#define SHORTFIN_DTYPE_HANDLE(et, ident) add(DType(et, #ident)); +#include "shortfin/array/dtypes.inl" +#undef SHORTFIN_DTYPE_HANDLE + return c; + })(); + + auto &c = static_canonical; + auto it = c.find(et); + if (it != c.end()) return it->second; + return DType(et, "opaque_imported"); +} + } // namespace shortfin::array diff --git a/libshortfin/src/shortfin/array/dtype.h b/libshortfin/src/shortfin/array/dtype.h index 090751162..eafb16a86 100644 --- a/libshortfin/src/shortfin/array/dtype.h +++ b/libshortfin/src/shortfin/array/dtype.h @@ -19,64 +19,10 @@ namespace shortfin::array { // Wraps an iree_hal_element_type into a DType like object. 
class SHORTFIN_API DType { public: - static DType opaque8() { - return DType(IREE_HAL_ELEMENT_TYPE_OPAQUE_8, "opaque8"); - } - static DType opaque16() { - return DType(IREE_HAL_ELEMENT_TYPE_OPAQUE_16, "opaque16"); - } - static DType opaque32() { - return DType(IREE_HAL_ELEMENT_TYPE_OPAQUE_32, "opaque32"); - } - static DType opaque64() { - return DType(IREE_HAL_ELEMENT_TYPE_OPAQUE_64, "opaque64"); - } - static DType bool8() { return DType(IREE_HAL_ELEMENT_TYPE_BOOL_8, "bool8"); } - static DType int4() { return DType(IREE_HAL_ELEMENT_TYPE_INT_4, "int4"); } - static DType sint4() { return DType(IREE_HAL_ELEMENT_TYPE_SINT_4, "sint4"); } - static DType uint4() { return DType(IREE_HAL_ELEMENT_TYPE_UINT_4, "uint4"); } - static DType int8() { return DType(IREE_HAL_ELEMENT_TYPE_INT_8, "int8"); } - static DType sint8() { return DType(IREE_HAL_ELEMENT_TYPE_SINT_8, "sint8"); } - static DType uint8() { return DType(IREE_HAL_ELEMENT_TYPE_UINT_8, "uint8"); } - static DType int16() { return DType(IREE_HAL_ELEMENT_TYPE_INT_16, "int16"); } - static DType sint16() { - return DType(IREE_HAL_ELEMENT_TYPE_SINT_16, "sint16"); - } - static DType uint16() { - return DType(IREE_HAL_ELEMENT_TYPE_UINT_16, "uint16"); - } - static DType int32() { return DType(IREE_HAL_ELEMENT_TYPE_INT_32, "int32"); } - static DType sint32() { - return DType(IREE_HAL_ELEMENT_TYPE_SINT_32, "sint32"); - } - static DType uint32() { - return DType(IREE_HAL_ELEMENT_TYPE_UINT_32, "uint32"); - } - static DType int64() { return DType(IREE_HAL_ELEMENT_TYPE_INT_64, "int64"); } - static DType sint64() { - return DType(IREE_HAL_ELEMENT_TYPE_SINT_64, "sint64"); - } - static DType uint64() { - return DType(IREE_HAL_ELEMENT_TYPE_UINT_64, "uint64"); - } - static DType float16() { - return DType(IREE_HAL_ELEMENT_TYPE_FLOAT_16, "float16"); - } - static DType float32() { - return DType(IREE_HAL_ELEMENT_TYPE_FLOAT_32, "float32"); - } - static DType float64() { - return DType(IREE_HAL_ELEMENT_TYPE_FLOAT_64, "float64"); - } - static DType bfloat16() { - return DType(IREE_HAL_ELEMENT_TYPE_BFLOAT_16, "bfloat16"); - } - static DType complex64() { - return DType(IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64, "complex64"); - } - static DType complex128() { - return DType(IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128, "complex128"); - } +#define SHORTFIN_DTYPE_HANDLE(et, ident) \ + static DType ident() { return DType(et, #ident); } +#include "shortfin/array/dtypes.inl" +#undef SHORTFIN_DTYPE_HANDLE operator iree_hal_element_type_t() const { return et_; } @@ -112,6 +58,9 @@ class SHORTFIN_API DType { bool operator==(const DType &other) const { return et_ == other.et_; } + // Imports a raw iree_hal_element_type_t from the ether. + static DType import_element_type(iree_hal_element_type_t et); + private: DType(iree_hal_element_type_t et, std::string_view name) : et_(et), name_(name) {} diff --git a/libshortfin/src/shortfin/array/dtypes.inl b/libshortfin/src/shortfin/array/dtypes.inl new file mode 100644 index 000000000..50be10461 --- /dev/null +++ b/libshortfin/src/shortfin/array/dtypes.inl @@ -0,0 +1,34 @@ +// Copyright 2024 Advanced Micro Devices, Inc +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Include file API for enumerating all known dtypes. 
+ +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_OPAQUE_8, opaque8) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_OPAQUE_16, opaque16) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_OPAQUE_32, opaque32) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_OPAQUE_64, opaque64) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_BOOL_8, bool8) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_INT_4, int4) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_SINT_4, sint4) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_UINT_4, uint4) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_INT_8, int8) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_SINT_8, sint8) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_UINT_8, uint8) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_INT_16, int16) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_SINT_16, sint16) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_UINT_16, uint16) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_INT_32, int32) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_SINT_32, sint32) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_UINT_32, uint32) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_INT_64, int64) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_SINT_64, sint64) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_UINT_64, uint64) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_FLOAT_16, float16) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_FLOAT_32, float32) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_FLOAT_64, float64) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_BFLOAT_16, bfloat16) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64, complex64) +SHORTFIN_DTYPE_HANDLE(IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128, complex128) diff --git a/libshortfin/src/shortfin/array/storage.cc b/libshortfin/src/shortfin/array/storage.cc index fa9e0f4b8..750c040b5 100644 --- a/libshortfin/src/shortfin/array/storage.cc +++ b/libshortfin/src/shortfin/array/storage.cc @@ -35,8 +35,14 @@ storage::storage(local::ScopedDevice device, iree::hal_buffer_ptr buffer, } storage::~storage() { logging::destruct("array::storage", this); } -storage storage::AllocateDevice(ScopedDevice &device, - iree_device_size_t allocation_size) { +storage storage::import_buffer(local::ScopedDevice &device, + iree::hal_buffer_ptr buffer) { + return storage(device, std::move(buffer), + device.scope().NewTimelineResource()); +} + +storage storage::allocate_device(ScopedDevice &device, + iree_device_size_t allocation_size) { if (!device.raw_device()) { throw std::invalid_argument("Cannot allocate with a null device affinity"); } @@ -54,8 +60,8 @@ storage storage::AllocateDevice(ScopedDevice &device, device.scope().NewTimelineResource()); } -storage storage::AllocateHost(ScopedDevice &device, - iree_device_size_t allocation_size) { +storage storage::allocate_host(ScopedDevice &device, + iree_device_size_t allocation_size) { if (!device.raw_device()) { throw std::invalid_argument("Cannot allocate with a null device affinity"); } @@ -64,7 +70,8 @@ storage storage::AllocateHost(ScopedDevice &device, iree_hal_buffer_params_t params = { .usage = IREE_HAL_BUFFER_USAGE_MAPPING, .access = IREE_HAL_MEMORY_ACCESS_ALL, - .type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_HOST, + .type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_HOST | + IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, .queue_affinity = device.affinity().queue_affinity(), }; if (device.affinity().queue_affinity() != 0) { @@ -76,7 +83,7 @@ storage storage::AllocateHost(ScopedDevice &device, device.scope().NewTimelineResource()); } -storage storage::Subspan(iree_device_size_t byte_offset, +storage 
storage::subspan(iree_device_size_t byte_offset, iree_device_size_t byte_length) { storage new_storage(device_, {}, timeline_resource_); SHORTFIN_THROW_IF_ERROR(iree_hal_buffer_subspan( @@ -84,7 +91,7 @@ storage storage::Subspan(iree_device_size_t byte_offset, return new_storage; } -void storage::Fill(const void *pattern, iree_host_size_t pattern_length) { +void storage::fill(const void *pattern, iree_host_size_t pattern_length) { device_.scope().scheduler().AppendCommandBuffer( device_, TransactionType::TRANSFER, [&](Account &account) { // Must depend on all of this buffer's use dependencies to avoid @@ -94,9 +101,8 @@ void storage::Fill(const void *pattern, iree_host_size_t pattern_length) { // write-after-write hazard. account.active_deps_extend(timeline_resource_->mutation_barrier()); - // TODO: I need to join the submission dependencies on the account - // with the timeline resource idle fence to ensure that - // write-after-access is properly sequenced. + SHORTFIN_SCHED_LOG(" : FillBuffer({})", + static_cast(buffer_.get())); SHORTFIN_THROW_IF_ERROR(iree_hal_command_buffer_fill_buffer( account.active_command_buffer(), iree_hal_make_buffer_ref( @@ -111,7 +117,7 @@ void storage::Fill(const void *pattern, iree_host_size_t pattern_length) { }); } -void storage::CopyFrom(storage &source_storage) { +void storage::copy_from(storage &source_storage) { device_.scope().scheduler().AppendCommandBuffer( device_, TransactionType::TRANSFER, [&](Account &account) { // Must depend on the source's mutation dependencies to avoid @@ -122,6 +128,9 @@ void storage::CopyFrom(storage &source_storage) { account.active_deps_extend(timeline_resource_->use_barrier()); account.active_deps_extend(timeline_resource_->mutation_barrier()); + SHORTFIN_SCHED_LOG(" : CopyBuffer({} -> {})", + static_cast(source_storage.buffer_.get()), + static_cast(buffer_.get())); SHORTFIN_THROW_IF_ERROR(iree_hal_command_buffer_copy_buffer( account.active_command_buffer(), /*source_ref=*/ @@ -129,10 +138,13 @@ void storage::CopyFrom(storage &source_storage) { /*target_ref=*/ iree_hal_make_buffer_ref(buffer_, 0, byte_length()))); - // And move our own mutation barrier to the current pending timeline + // Move our own mutation barrier to the current pending timeline // value. timeline_resource_->set_mutation_barrier( account.timeline_sem(), account.timeline_idle_timepoint()); + // And extend the source use barrier. 
+ source_storage.timeline_resource_->use_barrier_insert( + account.timeline_sem(), account.timeline_idle_timepoint()); }); } @@ -150,7 +162,7 @@ bool storage::is_mappable_for_read_write() const { (IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE)); } -void storage::MapExplicit(mapping &mapping, iree_hal_memory_access_t access) { +void storage::map_explicit(mapping &mapping, iree_hal_memory_access_t access) { assert(access != IREE_HAL_MEMORY_ACCESS_NONE); mapping.reset(); SHORTFIN_THROW_IF_ERROR(iree_hal_buffer_map_range( @@ -189,6 +201,68 @@ std::string storage::formatted_buffer_usage() const { return std::string(sv.data, sv.size); } +void storage::AddAsInvocationArgument(local::ProgramInvocation *inv, + local::ProgramResourceBarrier barrier) { + iree::vm_opaque_ref ref; + *(&ref) = iree_hal_buffer_retain_ref(buffer_); + inv->AddArg(std::move(ref)); + + AddInvocationArgBarrier(inv, barrier); +} + +iree_vm_ref_type_t storage::invocation_marshalable_type() { + return iree_hal_buffer_type(); +} + +storage storage::CreateFromInvocationResultRef(local::ProgramInvocation *inv, + iree::vm_opaque_ref ref) { + // Steal the ref to one of our smart pointers. + // TODO: Should have an opaque_ref::release(). + iree::hal_buffer_ptr buffer = + iree::hal_buffer_ptr::steal_reference(iree_hal_buffer_deref(*ref.get())); + (&ref)->ptr = nullptr; + return ImportInvocationResultStorage(inv, std::move(buffer)); +} + +storage storage::ImportInvocationResultStorage(local::ProgramInvocation *inv, + iree::hal_buffer_ptr buffer) { + local::ScopedDevice device = + local::ScopedDevice(*inv->scope(), inv->device_selection()); + auto imported_storage = storage::import_buffer(device, std::move(buffer)); + + auto coarse_signal = inv->coarse_signal(); + if (coarse_signal.first) { + SHORTFIN_SCHED_LOG("Storage buffer {}: Ready barrier {}@{}", + static_cast(imported_storage.buffer_.get()), + static_cast(coarse_signal.first), + coarse_signal.second); + imported_storage.timeline_resource_->set_mutation_barrier( + coarse_signal.first, coarse_signal.second); + imported_storage.timeline_resource_->use_barrier_insert( + coarse_signal.first, coarse_signal.second); + } + + return imported_storage; +} + +void storage::AddInvocationArgBarrier(local::ProgramInvocation *inv, + local::ProgramResourceBarrier barrier) { + switch (barrier) { + case ProgramResourceBarrier::DEFAULT: + case ProgramResourceBarrier::READ: + inv->wait_insert(timeline_resource_->mutation_barrier()); + inv->DeviceSelect(device_.affinity()); + break; + case ProgramResourceBarrier::WRITE: + inv->wait_insert(timeline_resource_->mutation_barrier()); + inv->wait_insert(timeline_resource_->use_barrier()); + inv->DeviceSelect(device_.affinity()); + break; + case ProgramResourceBarrier::NONE: + break; + } +} + std::string storage::to_s() const { return fmt::format("", static_cast(buffer_.get()), byte_length()); diff --git a/libshortfin/src/shortfin/array/storage.h b/libshortfin/src/shortfin/array/storage.h index 0db73d28f..a065905aa 100644 --- a/libshortfin/src/shortfin/array/storage.h +++ b/libshortfin/src/shortfin/array/storage.h @@ -9,6 +9,7 @@ #include +#include "shortfin/local/program_interfaces.h" #include "shortfin/local/scope.h" #include "shortfin/support/api.h" @@ -70,7 +71,7 @@ class SHORTFIN_API mapping { }; // Array storage backed by an IREE buffer of some form. 
-class SHORTFIN_API storage { +class SHORTFIN_API storage : public local::ProgramInvocationMarshalable { public: ~storage(); local::ScopedDevice &device() { return device_; } @@ -78,32 +79,35 @@ class SHORTFIN_API storage { const local::ScopedDevice &device() const { return device_; } local::Scope &scope() const { return device_.scope(); } + static storage import_buffer(local::ScopedDevice &device, + iree::hal_buffer_ptr buffer); + // Allocates device storage, compatible with the given device affinity. // By default, this will be IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE. - static storage AllocateDevice(local::ScopedDevice &device, - iree_device_size_t allocation_size); + static storage allocate_device(local::ScopedDevice &device, + iree_device_size_t allocation_size); // Allocates host storage, compatible with the given device affinity. // By default, if there are any affinity bits set in the device, then // the storage will be device visible and have permitted usage for // transfers. This default policy can be overriden based on device defaults // or explicit options. - static storage AllocateHost(local::ScopedDevice &device, - iree_device_size_t allocation_size); + static storage allocate_host(local::ScopedDevice &device, + iree_device_size_t allocation_size); // Creates a subspan view of the current storage given a byte offset and // length. The returned storage shares the underlying allocation and // scheduling control block. - storage Subspan(iree_device_size_t byte_offset, + storage subspan(iree_device_size_t byte_offset, iree_device_size_t byte_length); // Enqueues a fill of the storage with an arbitrary pattern of the given // size. The pattern size must be 1, 2, or 4. - void Fill(const void *pattern, iree_host_size_t pattern_length); + void fill(const void *pattern, iree_host_size_t pattern_length); // Performs either a d2h, h2d or d2d transfer from a source storage to this // storage. - void CopyFrom(storage &source_storage); + void copy_from(storage &source_storage); iree_device_size_t byte_length() const { return iree_hal_buffer_byte_length(buffer_.get()); @@ -124,33 +128,33 @@ class SHORTFIN_API storage { bool is_mappable_for_read_write() const; // Maps the memory for access from a host pointer using a scoped mapping. - void MapExplicit(mapping &mapping, iree_hal_memory_access_t access); + void map_explicit(mapping &mapping, iree_hal_memory_access_t access); // Maps the memory for read/write access, preserving any contents. - mapping MapReadWrite() { + mapping map_read_write() { mapping m; - MapExplicit(m, IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE); + map_explicit(m, IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE); return m; } // Maps the memory for discard write. This is used if populating an initial // buffer. - mapping MapWriteDiscard() { + mapping map_write_discard() { mapping m; - MapExplicit(m, IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE); + map_explicit(m, IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE); return m; } // Maps the memory for read-only access. - mapping MapRead() { + mapping map_read() { mapping m; - MapExplicit(m, IREE_HAL_MEMORY_ACCESS_READ); + map_explicit(m, IREE_HAL_MEMORY_ACCESS_READ); return m; } - const mapping MapRead() const { + const mapping map_read() const { mapping m; - const_cast(this)->MapExplicit(m, IREE_HAL_MEMORY_ACCESS_READ); + const_cast(this)->map_explicit(m, IREE_HAL_MEMORY_ACCESS_READ); return m; } @@ -161,15 +165,39 @@ class SHORTFIN_API storage { // underlying device references alive as needed). 
operator iree_hal_buffer_t *() { return buffer_; } + iree_allocator_t host_allocator() { + return timeline_resource_->host_allocator(); + } + private: storage(local::ScopedDevice device, iree::hal_buffer_ptr buffer, local::detail::TimelineResource::Ref timeline_resource); + // ProgramInvocationMarshalable implementation. + void AddAsInvocationArgument(local::ProgramInvocation *inv, + local::ProgramResourceBarrier barrier) override; + static storage CreateFromInvocationResultRef(local::ProgramInvocation *inv, + iree::vm_opaque_ref ref); + static iree_vm_ref_type_t invocation_marshalable_type(); + + // Adds any necessary wait barriers to the invocation on behalf of this + // storage. + void AddInvocationArgBarrier(local::ProgramInvocation *inv, + local::ProgramResourceBarrier barrier); + + // Imports a raw hal buffer from an invocation as a storage, attaching any + // needed barriers. + static storage ImportInvocationResultStorage(local::ProgramInvocation *inv, + iree::hal_buffer_ptr buffer); + // The timeline resource holds the back reference to the owning scope, // which keeps all devices alive. Buffers must be destroyed before devices, // so this must be declared first. local::detail::TimelineResource::Ref timeline_resource_; iree::hal_buffer_ptr buffer_; local::ScopedDevice device_; + + friend class shortfin::local::ProgramInvocationMarshalableFactory; + friend class device_array; }; // Wraps an untyped mapping, providing typed access. diff --git a/libshortfin/src/shortfin/local/async.cc b/libshortfin/src/shortfin/local/async.cc index a5e005768..934fc0d46 100644 --- a/libshortfin/src/shortfin/local/async.cc +++ b/libshortfin/src/shortfin/local/async.cc @@ -173,7 +173,7 @@ void Future::ThrowFailureWithLockHeld() { if (!state_->done_) { throw std::logic_error("Cannot get result from Future that is not done"); } - SHORTFIN_THROW_IF_ERROR(state_->failure_status_); + SHORTFIN_THROW_IF_ERROR(state_->failure_status_.ConsumeStatus()); } } // namespace shortfin::local diff --git a/libshortfin/src/shortfin/local/async.h b/libshortfin/src/shortfin/local/async.h index 094c90707..5046f5567 100644 --- a/libshortfin/src/shortfin/local/async.h +++ b/libshortfin/src/shortfin/local/async.h @@ -166,6 +166,16 @@ class SHORTFIN_API TypedFuture : public Future { return *this; } + // Futures are non-nullable, so construct/assign from an rvalue reference + // is just a copy and does not clear the original. + TypedFuture(TypedFuture &&other) : Future(other.state_) { Retain(); } + TypedFuture &operator=(TypedFuture &&other) { + other.Retain(); + Release(); + state_ = other.state_; + return *this; + } + void set_result(ResultTy result) { iree::slim_mutex_lock_guard g(state_->lock_); if (state_->done_) { diff --git a/libshortfin/src/shortfin/local/device.h b/libshortfin/src/shortfin/local/device.h index ccfc30ef6..48482f63c 100644 --- a/libshortfin/src/shortfin/local/device.h +++ b/libshortfin/src/shortfin/local/device.h @@ -194,6 +194,7 @@ class SHORTFIN_API DeviceAffinity { return result; } + operator bool() const { return device_ != nullptr; } Device *device() const { return device_; } iree_hal_queue_affinity_t queue_affinity() const { return queue_affinity_; } // Returns the lowest queue ordinal in the affinity set. 
If there are no diff --git a/libshortfin/src/shortfin/local/messaging.h b/libshortfin/src/shortfin/local/messaging.h index f006775fe..fc1f3173b 100644 --- a/libshortfin/src/shortfin/local/messaging.h +++ b/libshortfin/src/shortfin/local/messaging.h @@ -121,9 +121,9 @@ class SHORTFIN_API Message { // sized field that the allocator can use at it sees fit. Both fields // are managed within a lock_ scope and are optimized for single threaded // access and cross-thread transfers with coarse references. + mutable iree::slim_mutex lock_; mutable intptr_t ref_data_ = 1; mutable detail::MessageRefOwner owner_; - mutable iree::slim_mutex lock_; friend struct detail::MessageRefOwner; }; diff --git a/libshortfin/src/shortfin/local/program.cc b/libshortfin/src/shortfin/local/program.cc index cba725096..c7a045eeb 100644 --- a/libshortfin/src/shortfin/local/program.cc +++ b/libshortfin/src/shortfin/local/program.cc @@ -8,13 +8,85 @@ #include "fmt/core.h" #include "fmt/std.h" +#include "iree/modules/hal/module.h" #include "iree/vm/bytecode/module.h" +#include "shortfin/local/scope.h" #include "shortfin/local/system.h" +#include "shortfin/support/logging.h" namespace shortfin::local { -ProgramModule ProgramModule::Load(System& system, - const std::filesystem::path& path, +namespace { +void GetVmModuleExports(iree_vm_module_t *vm_module, + std::vector &exports) { + auto sig = iree_vm_module_signature(vm_module); + for (iree_host_size_t i = 0; i < sig.export_function_count; ++i) { + iree_vm_function_t f; + SHORTFIN_THROW_IF_ERROR(iree_vm_module_lookup_function_by_ordinal( + vm_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &f)); + exports.emplace_back(to_string_view(iree_vm_function_name(&f))); + } +} +} // namespace + +// -------------------------------------------------------------------------- // +// ProgramFunction +// -------------------------------------------------------------------------- // + +ProgramFunction::ProgramFunction( + std::shared_ptr scope, iree::vm_context_ptr vm_context, + iree_vm_function_t vm_function, + std::optional invocation_model) + : scope_(std::move(scope)), + vm_context_(std::move(vm_context)), + vm_function_(vm_function), + invocation_model_(invocation_model + ? 
*invocation_model + : GetInvocationModelFromFunction(vm_function)) {} + +ProgramInvocationModel ProgramFunction::GetInvocationModelFromFunction( + iree_vm_function_t &f) { + iree_string_view_t invocation_model_sv = + iree_vm_function_lookup_attr_by_name(&f, IREE_SV("iree.abi.model")); + if (iree_string_view_equal(invocation_model_sv, IREE_SV("coarse-fences"))) { + return ProgramInvocationModel::COARSE_FENCES; + } else if (invocation_model_sv.size == 0) { + return ProgramInvocationModel::NONE; + } else { + logging::warn("Unknown function invocation model '{}': '{}'", + to_string_view(iree_vm_function_name(&f)), + to_string_view(invocation_model_sv)); + return ProgramInvocationModel::UNKNOWN; + } +} + +std::string_view ProgramFunction::name() const { + if (!*this) return {}; + return to_string_view(iree_vm_function_name(&vm_function_)); +} + +std::string_view ProgramFunction::calling_convention() const { + if (!*this) return {}; + return to_string_view( + iree_vm_function_signature(&vm_function_).calling_convention); +} + +ProgramInvocation::Ptr ProgramFunction::CreateInvocation() { + return ProgramInvocation::New(scope_, vm_context_, vm_function_, + invocation_model_); +} + +std::string ProgramFunction::to_s() const { + if (!*this) return std::string("ProgramFunction(NULL)"); + return fmt::format("ProgramFunction({}: {})", name(), calling_convention()); +} + +// -------------------------------------------------------------------------- // +// ProgramModule +// -------------------------------------------------------------------------- // + +ProgramModule ProgramModule::Load(System &system, + const std::filesystem::path &path, bool mmap) { iree::file_contents_ptr contents; iree_file_read_flags_t flags = @@ -53,4 +125,412 @@ std::string ProgramModule::to_s() const { sig.version, fmt::join(exports, ", ")); } +std::vector ProgramModule::exports() const { + std::vector exports; + GetVmModuleExports(vm_module_, exports); + return exports; +} + +// -------------------------------------------------------------------------- // +// Program +// -------------------------------------------------------------------------- // + +Program Program::Load(std::shared_ptr scope, + std::span modules, Options options) { + std::vector all_modules; + std::vector raw_devices; + + // By default, bind all devices in the scope in order to the program. + for (Device *d : scope->raw_devices()) { + raw_devices.push_back(d->hal_device()); + } + + // Add a HAL module. + // TODO: at some point may want to change this to something similar to + // what the tooling does in iree_tooling_resolve_modules - it uses + // iree_vm_module_enumerate_dependencies to walk the dependencies and add the + // required modules only as needed. to start you could use it just to see if + // the hal is used, but as you add other module types for exposing sharkfin + // functionality (or module versions; iree_vm_module_dependency_t has the + // minimum version required so you can switch between them, and whether they + // are optional/required). + auto &system = scope->system(); + iree::vm_module_ptr hal_module; + SHORTFIN_THROW_IF_ERROR( + iree_hal_module_create(system.vm_instance(), raw_devices.size(), + raw_devices.data(), IREE_HAL_MODULE_FLAG_NONE, + system.host_allocator(), hal_module.for_output())); + all_modules.push_back(hal_module); + + // Add explicit modules. + for (auto &pm : modules) { + all_modules.push_back(pm.vm_module()); + } + + // Create the context. 
+ iree::vm_context_ptr context; + iree_vm_context_flags_t flags = IREE_VM_CONTEXT_FLAG_CONCURRENT; + if (options.trace_execution) flags |= IREE_VM_CONTEXT_FLAG_TRACE_EXECUTION; + SHORTFIN_THROW_IF_ERROR(iree_vm_context_create_with_modules( + system.vm_instance(), flags, all_modules.size(), all_modules.data(), + system.host_allocator(), context.for_output())); + + return Program(std::move(scope), std::move(context)); +} + +std::optional Program::LookupFunction(std::string_view name) { + // By convention, we currently name our coarse-fences function variants + // as ending in "$async". These are the ones we want but it is inconvenient. + // Therefore, we probe for that first. + // TODO: We should add attributes to the function that better describe this + // relationship. + iree_vm_function_t f; + if (!name.ends_with("$async")) { + std::string async_name(name); + async_name.append("$async"); + iree_status_t status = iree_vm_context_resolve_function( + vm_context_, to_iree_string_view(async_name), &f); + if (iree_status_is_ok(status)) { + // TODO: Torch import is not setting the coarse-fences abi.model on + // its functions. Get it from there instead of just assuming based on + // name. + return ProgramFunction(scope_, vm_context_, f, + ProgramInvocationModel::COARSE_FENCES); + } else if (!iree_status_is_not_found(status)) { + SHORTFIN_THROW_IF_ERROR(status); + } + } + + // Resolve the exactly named function. + iree_status_t status = iree_vm_context_resolve_function( + vm_context_, to_iree_string_view(name), &f); + if (iree_status_is_not_found(status)) return {}; + SHORTFIN_THROW_IF_ERROR(status); + return ProgramFunction(scope_, vm_context_, f); +} + +ProgramFunction Program::LookupRequiredFunction(std::string_view name) { + auto f = LookupFunction(name); + if (!f) { + throw std::invalid_argument( + fmt::format("Function '{}' not found in program. Available exports: {}", + name, fmt::join(exports(), ", "))); + } + return std::move(*f); +} + +std::vector Program::exports() const { + std::vector results; + + // Iterate in reverse since "user modules" are typically last. + int module_count = iree_vm_context_module_count(vm_context_); + for (int i = module_count - 1; i >= 0; --i) { + auto vm_module = iree_vm_context_module_at(vm_context_, i); + std::string_view module_name = + to_string_view(iree_vm_module_name(vm_module)); + std::vector names; + GetVmModuleExports(vm_module, names); + for (auto &name : names) { + results.push_back(fmt::format("{}.{}", module_name, name)); + } + } + return results; +} + +// -------------------------------------------------------------------------- // +// ProgramInvocation +// -------------------------------------------------------------------------- // + +iree_vm_list_t *ProgramInvocation::arg_list() { + // The arg list is located immediately after this, allocated as a trailing + // data structure. + return reinterpret_cast(reinterpret_cast(this) + + sizeof(*this)); +} + +void ProgramInvocation::Deleter::operator()(ProgramInvocation *inst) { + inst->~ProgramInvocation(); + uint8_t *memory = static_cast(static_cast(inst)); + + // Trailing arg list and result list. The arg list pointer is only available + // at construction, so we use the knowledge that it is stored right after + // the object. The result_list_ is available for the life of the invocation. 
+ iree_vm_list_deinitialize(static_cast( + static_cast(memory + sizeof(ProgramInvocation)))); + iree_vm_list_deinitialize(inst->result_list_); + + // Was allocated in New as a uint8_t[] so delete it by whence it came. + delete[] memory; +} + +ProgramInvocation::ProgramInvocation() = default; +ProgramInvocation::~ProgramInvocation() { + if (!scheduled()) { + // This instance was dropped on the floor before scheduling. + // Clean up the initialization parameters. + iree::vm_context_ptr drop = + iree::vm_context_ptr::steal_reference(state.params.context); + } +} + +ProgramInvocation::Ptr ProgramInvocation::New( + std::shared_ptr scope, iree::vm_context_ptr vm_context, + iree_vm_function_t &vm_function, ProgramInvocationModel invocation_model) { + auto sig = iree_vm_function_signature(&vm_function); + iree_host_size_t arg_count; + iree_host_size_t result_count; + SHORTFIN_THROW_IF_ERROR(iree_vm_function_call_count_arguments_and_results( + &sig, &arg_count, &result_count)); + + // Compute size of trailing arg/result storage. + auto variant_type_def = iree_vm_make_undefined_type_def(); + iree_host_size_t arg_storage_size = + iree_vm_list_storage_size(&variant_type_def, arg_count); + iree_host_size_t result_storage_size = + iree_vm_list_storage_size(&variant_type_def, result_count); + + // Allocate storage for the ProgramInvocation, arg, result list and placement + // new the ProgramInvocation into the storage area. + std::unique_ptr inst_storage( + new uint8_t[sizeof(ProgramInvocation) + arg_storage_size + + result_storage_size]); + new (inst_storage.get()) ProgramInvocation(); + + // Initialize trailing lists. Abort on failure since this is a bug and we + // would otherwise leak. + iree_vm_list_t *arg_list; + iree_vm_list_t *result_list; + IREE_CHECK_OK(iree_vm_list_initialize( + {.data = inst_storage.get() + sizeof(ProgramInvocation), + .data_length = arg_storage_size}, + &variant_type_def, arg_count, &arg_list)); + IREE_CHECK_OK(iree_vm_list_initialize( + {.data = + inst_storage.get() + sizeof(ProgramInvocation) + arg_storage_size, + .data_length = result_storage_size}, + &variant_type_def, result_count, &result_list)); + + Ptr inst(static_cast( + static_cast(inst_storage.release())), + Deleter()); + inst->scope_ = std::move(scope); + inst->state.params.context = + vm_context.release(); // Ref transfer to ProgramInvocation. + inst->state.params.function = vm_function; + inst->state.params.invocation_model = invocation_model; + inst->result_list_ = result_list; + return inst; +} + +void ProgramInvocation::CheckNotScheduled() { + if (scheduled()) { + throw std::logic_error("Cannot mutate an invocation once scheduled."); + } +} + +void ProgramInvocation::AddArg(iree::vm_opaque_ref ref) { + CheckNotScheduled(); + SHORTFIN_THROW_IF_ERROR(iree_vm_list_push_ref_move(arg_list(), &ref)); +} + +void ProgramInvocation::AddArg(iree_vm_ref_t *ref) { + CheckNotScheduled(); + SHORTFIN_THROW_IF_ERROR(iree_vm_list_push_ref_retain(arg_list(), ref)); +} + +iree_status_t ProgramInvocation::FinalizeCallingConvention( + iree_vm_list_t *arg_list, iree_vm_function_t &function, + ProgramInvocationModel invocation_model) { + // Handle post-processing invocation model setup. + if (invocation_model == ProgramInvocationModel::COARSE_FENCES) { + // If we have a device_selection, set up to signal the leader account. 
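The single-allocation layout used by ProgramInvocation::New (object header followed by the trailing arg and result list storage) is a general technique. A minimal standalone sketch of the same pattern, with hypothetical names and none of the IREE specifics:

#include <cstddef>
#include <cstdint>
#include <memory>
#include <new>

struct Inst { bool scheduled = false; };  // stand-in for ProgramInvocation

// One heap block: [Inst][trailing payload bytes], as in ProgramInvocation::New.
Inst *NewWithTrailing(std::size_t payload_bytes) {
  std::unique_ptr<uint8_t[]> storage(new uint8_t[sizeof(Inst) + payload_bytes]);
  Inst *inst = new (storage.get()) Inst();  // placement-new into the block
  storage.release();                        // ownership now travels with `inst`
  return inst;
}

// The payload needs no stored pointer; it lives at a fixed offset.
uint8_t *Payload(Inst *inst) {
  return reinterpret_cast<uint8_t *>(inst) + sizeof(Inst);
}

// Teardown mirrors Deleter::operator(): run the destructor, then free the
// block by the type it was allocated as.
void Delete(Inst *inst) {
  inst->~Inst();
  delete[] reinterpret_cast<uint8_t *>(inst);
}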
+ if (device_selection_) { + ScopedDevice scoped_device(*scope(), device_selection_); + auto &sched_account = + scope()->scheduler().GetDefaultAccount(scoped_device); + iree_hal_fence_t *wait_fence = this->wait_fence(); + iree_hal_semaphore_t *timeline_sem = sched_account.timeline_sem(); + uint64_t timeline_now = sched_account.timeline_idle_timepoint(); + SHORTFIN_SCHED_LOG("Invocation {}: Wait on account timeline {}@{}", + static_cast(this), + static_cast(timeline_sem), timeline_now); + IREE_RETURN_IF_ERROR( + iree_hal_fence_insert(wait_fence, timeline_sem, timeline_now)); + signal_sem_ = sched_account.timeline_sem(); + signal_timepoint_ = sched_account.timeline_acquire_timepoint(); + } + + // Push wait fence (or null if no wait needed). + ::iree::vm::ref wait_ref; + if (wait_fence_) { + ::iree::vm::retain_ref(wait_fence()); + } + IREE_RETURN_IF_ERROR(iree_vm_list_push_ref_move(arg_list, wait_ref)); + + // Create and push signal fence (or null if no signal needed). + ::iree::vm::ref signal_ref; + if (signal_sem_) { + SHORTFIN_SCHED_LOG("Invocation {}: Set signal {}@{}", + static_cast(this), + static_cast(signal_sem_), signal_timepoint_); + IREE_RETURN_IF_ERROR( + iree_hal_fence_create_at(signal_sem_, signal_timepoint_, + scope()->host_allocator(), &signal_ref)); + } + IREE_RETURN_IF_ERROR(iree_vm_list_push_ref_move(arg_list, signal_ref)); + } else { + logging::warn( + "Invoking function '{}' with unknown or synchronous invocation model " + "is not fully supported", + to_string_view(iree_vm_function_name(&function))); + } + + return iree_ok_status(); +} + +ProgramInvocation::Future ProgramInvocation::Invoke( + ProgramInvocation::Ptr invocation) { + invocation->CheckNotScheduled(); + + Worker &worker = invocation->scope_->worker(); + // We're about to overwrite the instance level storage for params, so move + // it to the stack and access there. + Params params = invocation->state.params; + + auto schedule = [](ProgramInvocation *raw_invocation, Worker *worker, + iree_vm_context_t *owned_context, + iree_vm_function_t function, + ProgramInvocationModel invocation_model, + std::optional failure_future) { + auto complete_callback = + [](void *user_data, iree_loop_t loop, iree_status_t status, + iree_vm_list_t *outputs) noexcept -> iree_status_t { + // Async invocation helpfully gives us a retained reference to the + // outputs, but we already have one statically on the + // ProgramInvocation. So release this one, which makes it safe to + // deallocate the ProgramInvocation at any point after this (there + // must be no live references to inputs/outputs when the + // ProgramInvocation::Ptr deleter is invoked). + iree::vm_list_ptr::steal_reference(outputs); + + // Repatriate the ProgramInvocation. + ProgramInvocation::Ptr invocation( + static_cast(user_data)); + ProgramInvocation *raw_invocation = invocation.get(); + if (iree_status_is_ok(status)) { + raw_invocation->future_->set_result(std::move(invocation)); + } else { + raw_invocation->future_->set_failure(status); + } + + // Must release the future from the invocation to break the + // circular reference (we are setting the invocation as the result + // of the future). + raw_invocation->future_.reset(); + + return iree_ok_status(); + }; + + ProgramInvocation::Ptr invocation(raw_invocation); + iree_status_t status = iree_ok_status(); + + // Multiple steps needed to schedule need to all exit via the same + // path. 
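For a coarse-fences export, the finalized argument list is the user arguments followed by two trailing fence arguments in the order pushed here (wait, then signal). A hedged sketch of the resulting layout and of observing completion via coarse_signal(); `invocation` is assumed to be an already-scheduled ProgramInvocation, and the semaphore wait uses the same IREE API as Account::OnSync further down:

// Argument layout for "f$async" with two user inputs (illustrative):
//   [0] user input 0
//   [1] user input 1
//   [2] wait fence   - execution may not begin until it is signaled (or null)
//   [3] signal fence - signaled by the program when results are ready (or null)

// Coarse completion can be observed directly (blocking; in practice this
// would be handed to the blocking executor rather than a worker thread):
auto [sem, timepoint] = invocation->coarse_signal();
if (sem) {
  IREE_CHECK_OK(
      iree_hal_semaphore_wait(sem, timepoint, iree_infinite_timeout()));
}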
+ if (iree_status_is_ok(status)) { + status = invocation->scope()->scheduler().FlushWithStatus(); + } + if (iree_status_is_ok(status)) { + status = invocation->FinalizeCallingConvention( + invocation->arg_list(), function, invocation_model); + } + if (iree_status_is_ok(status)) { + status = iree_vm_async_invoke(worker->loop(), + &invocation->state.async_invoke_state, + owned_context, function, + /*flags=*/IREE_VM_INVOCATION_FLAG_NONE, + /*policy=*/nullptr, + /*inputs=*/invocation->arg_list(), + /*outputs=*/invocation->result_list_, + iree_allocator_system(), +complete_callback, + /*user_data=*/invocation.get()); + } + + // Regardless of status, the context reference we were holding is no + // longer needed. Drop it on the floor. + iree::vm_context_ptr::steal_reference(owned_context); + + // On success, then the complete callback takes ownership of the + // invocation, so we release it here and return. We have to treat + // the invocation as possibly deallocated at this point, since the + // async invocation may have finished already. + if (iree_status_is_ok(status)) { + invocation.release(); + } else if (failure_future) { + // Requested to set any failure on the future. + failure_future->set_failure(status); + } else { + // Synchronous: just throw. + SHORTFIN_THROW_IF_ERROR(status); + } + }; + + // Transition to the scheduled state. + invocation->future_.emplace(&worker); + auto fork_future = *invocation->future_; + invocation->scheduled_ = true; + + if (&worker == Worker::GetCurrent()) { + // On the same worker: fast-path directly to the loop. + schedule(invocation.release(), &worker, params.context, params.function, + params.invocation_model, /*failure_future=*/{}); + } else { + // Cross worker coordination: submit an external task to bootstrap. + auto bound_schedule = + std::bind(schedule, invocation.release(), &worker, params.context, + params.function, params.invocation_model, + /*failure_future=*/fork_future); + worker.CallThreadsafe(bound_schedule); + } + + return fork_future; +} + +iree_host_size_t ProgramInvocation::results_size() { + return iree_vm_list_size(result_list_); +} + +iree::vm_opaque_ref ProgramInvocation::result_ref(iree_host_size_t i) { + iree::vm_opaque_ref out_value; + auto status = iree_vm_list_get_ref_retain(result_list_, i, &out_value); + if (iree_status_is_failed_precondition(status)) return {}; + SHORTFIN_THROW_IF_ERROR(status, "accessing invocation result"); + return out_value; +} + +iree_hal_fence_t *ProgramInvocation::wait_fence() { + if (!wait_fence_) { + wait_fence_ = scope_->scheduler().NewFence(); + } + return wait_fence_.get(); +} + +void ProgramInvocation::wait_insert(iree_hal_semaphore_list_t sem_list) { + iree_hal_fence_t *f = wait_fence(); + for (iree_host_size_t i = 0; i < sem_list.count; ++i) { + SHORTFIN_SCHED_LOG("Invocation {}: Wait on {}@{}", + static_cast(this), + static_cast(sem_list.semaphores[i]), + sem_list.payload_values[i]); + SHORTFIN_THROW_IF_ERROR(iree_hal_fence_insert(f, sem_list.semaphores[i], + sem_list.payload_values[i])); + } +} + +void ProgramInvocation::DeviceSelect(DeviceAffinity device_affinity) { + CheckNotScheduled(); + SHORTFIN_SCHED_LOG("Invocation {}: DeviceSelect {}", + static_cast(this), device_affinity.to_s()); + device_selection_ |= device_affinity; +} + } // namespace shortfin::local diff --git a/libshortfin/src/shortfin/local/program.h b/libshortfin/src/shortfin/local/program.h index 40637a768..e701b0489 100644 --- a/libshortfin/src/shortfin/local/program.h +++ b/libshortfin/src/shortfin/local/program.h @@ -8,15 
+8,205 @@ #define SHORTFIN_LOCAL_PROGRAM_H #include +#include +#include #include +#include +#include "shortfin/local/async.h" +#include "shortfin/local/device.h" +#include "shortfin/local/program_interfaces.h" +#include "shortfin/local/worker.h" #include "shortfin/support/api.h" #include "shortfin/support/iree_helpers.h" namespace shortfin::local { +class SHORTFIN_API Scope; class SHORTFIN_API System; +enum class ProgramInvocationModel { + // Uses the coarse-fences invocation model. In this model, the last two + // arguments are a wait and signal fence, which are used for function-level + // scheduling. + COARSE_FENCES, + // The function was not annotated with an invocation model. + NONE, + // The function is not annotated or is simple/synchronous. + UNKNOWN, +}; + +// State related to making an invocation of a function on a program. +// +// Since ownership of this object is transferred to the loop/callback and +// internal pointers into it must remain stable, it is only valid to heap +// allocate it. +class SHORTFIN_API ProgramInvocation { + struct Deleter { + void operator()(ProgramInvocation *); + }; + + public: + // The fact that we traffic in invocation pointers based on unique_ptr + // is incidental. By cloaking its public interface this way, we use the + // unique_ptr machinery but template meta-programming that is specialized + // for unique_ptr sees this as a bespoke class (which is what we want because + // ownership semantics are special). + class Ptr : private std::unique_ptr { + public: + using unique_ptr::unique_ptr; + using unique_ptr::operator=; + using unique_ptr::operator->; + using unique_ptr::operator bool; + using unique_ptr::get; + using unique_ptr::release; + }; + static_assert(sizeof(Ptr) == sizeof(void *)); + using Future = TypedFuture; + + static Ptr New(std::shared_ptr scope, iree::vm_context_ptr vm_context, + iree_vm_function_t &vm_function, + ProgramInvocationModel invocation_model); + ProgramInvocation(const ProgramInvocation &) = delete; + ProgramInvocation &operator=(const ProgramInvocation &) = delete; + ProgramInvocation &operator=(ProgramInvocation &&) = delete; + ProgramInvocation(ProgramInvocation &&inv) = delete; + ~ProgramInvocation(); + + // Whether the ProgramInvocation has entered the scheduled state. Once + // scheduled, arguments and initialization parameters can no longer be + // accessed. + bool scheduled() const { return scheduled_; } + + // The scope this invocation was scheduled against. + Scope *scope() const { return scope_.get(); } + + // Adds wait barriers to the invocation. For coarse fences invocations, these + // will cause execution of the function to wait until all sempahores added + // thusly are satisfied. + void wait_insert(iree_hal_semaphore_list_t sem_list); + + // Adds a marshalable argument with a configurable concurrency barrier. + void AddArg(ProgramInvocationMarshalable &marshalable, + ProgramResourceBarrier barrier = ProgramResourceBarrier::READ); + + // Adds a ref object argument. This low level interface directly adds a + // reference object and does not manipulate any execution barriers. + void AddArg(iree::vm_opaque_ref ref); // Moves a reference in. + void AddArg(iree_vm_ref_t *ref); // Borrows the reference. + + // Transfers ownership of an invocation and schedules it on worker, returning + // a future that will resolve to the owned invocation upon completion. + static ProgramInvocation::Future Invoke(ProgramInvocation::Ptr invocation); + + // Gets the number of outputs. 
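The Ptr wrapper reuses std::unique_ptr's machinery while hiding the fact that it is one, so generic code that special-cases unique_ptr cannot silently take ownership; only the narrow surface re-exported here is available. The same cloaking pattern in a self-contained sketch (Widget and its deleter are hypothetical):

#include <memory>

struct Widget { int value = 0; };

struct WidgetDeleter {
  void operator()(Widget *w) { delete w; }  // custom teardown hook
};

// Privately inherit to reuse unique_ptr's machinery while presenting a
// distinct type with a deliberately narrow surface.
class WidgetPtr : private std::unique_ptr<Widget, WidgetDeleter> {
 public:
  using unique_ptr::unique_ptr;
  using unique_ptr::operator=;
  using unique_ptr::operator->;
  using unique_ptr::operator bool;
  using unique_ptr::get;
  using unique_ptr::release;
};
static_assert(sizeof(WidgetPtr) == sizeof(void *));

// Usage: ownership moves explicitly, and APIs that pattern-match on
// std::unique_ptr do not see WidgetPtr as one.
//   WidgetPtr p(new Widget());
//   Widget *raw = p.release();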
+ iree_host_size_t results_size(); + + // Gets the i'th result as an opaque ref object. Returns a null ref if the + // result is a primitive. Outputs accessed in this way are not marshaled + // nor do they have concurrency barriers applied. + iree::vm_opaque_ref result_ref(iree_host_size_t i); + + // As arguments are processed, the device they are associated with should be + // passed here. The accumulation of these will drive the selection of the + // scheduling account used for the invocation timeline. In the absence of + // a specific directive, all arguments implicated in scheduling (i.e. + // excepting those with ProgramResourceBarrier::NONE) must be on the same + // logical device and only differ by queue affinity. + // This method will raise an exception if the implied semantics are violated. + void DeviceSelect(DeviceAffinity device_affinity); + + // Selected device affinity used for scheduling. + const DeviceAffinity &device_selection() { return device_selection_; } + + // If this invocation provides coarse signaling of result availability, + // the semaphore and timepoint are returned here. If the semaphore is null, + // then coarse signaling is not available. + // Valid after invocation has been scheduled. + std::pair coarse_signal() { + return std::make_pair(signal_sem_, signal_timepoint_); + } + + private: + ProgramInvocation(); + void CheckNotScheduled(); + + // Returns a pointer to the trailing arg list. + iree_vm_list_t *arg_list(); + + // Accesses the invocation owned wait fence, creating it if needed. + iree_hal_fence_t *wait_fence(); + + // Called as part of scheduling to finalize the calling convention and + // invocation model after user arguments have been added. Because this is + // potentially run in a foreign callback context, it uses iree_status_t + // error reporting vs exceptions. + iree_status_t FinalizeCallingConvention( + iree_vm_list_t *arg_list, iree_vm_function_t &function, + ProgramInvocationModel invocation_model); + + // Parameters needed to make the async call are stored at construction time + // up until the point the call is made in the params union. When invoking, + // these will be copied to the stack and passed to the async invocation, + // which initializes the async_invoke_state. Phasing it like this saves + // memory that would otherwise be retained for the life of the invocation. + // This must not contain entities that require destruction or cannot be + // trivially copied. + struct Params { + // Context is retained upon construction and released when scheduled. + iree_vm_context_t *context; + iree_vm_function_t function; + ProgramInvocationModel invocation_model; + }; + union State { + State() { new (¶ms) Params(); } + ~State() {} + Params params; + iree_vm_async_invoke_state_t async_invoke_state; + } state; + + std::shared_ptr scope_; + iree_vm_list_t *result_list_ = nullptr; + std::optional future_; + iree::hal_fence_ptr wait_fence_; + iree_hal_semaphore_t *signal_sem_ = nullptr; + uint64_t signal_timepoint_ = 0; + DeviceAffinity device_selection_; + bool scheduled_ = false; +}; + +// References a function in a Program. 
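The Params/State union overlays the small construction-time parameters with the much larger iree_vm_async_invoke_state_t because the two phases never overlap: params are copied out just before scheduling, then the same storage backs the async state. A reduced sketch of that phasing trick with hypothetical, trivially-copyable types:

#include <new>

struct SetupParams { void *context; int function_id; };  // pre-launch only
struct RunState { unsigned char machinery[256]; };        // post-launch only

union Phased {
  Phased() { new (&params) SetupParams(); }  // storage starts life as params
  ~Phased() {}                               // trivial members, no teardown
  SetupParams params;
  RunState run;
};

void Launch(Phased &p) {
  SetupParams local = p.params;  // copy out before the storage is repurposed
  new (&p.run) RunState();       // the same bytes now back the run-phase state
  // ... pass `local` to the async launch that fills p.run ...
  (void)local;
}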
+class SHORTFIN_API ProgramFunction { + public: + operator bool() const { return vm_context_; } + + std::string_view name() const; + std::string_view calling_convention() const; + ProgramInvocationModel invocation_model() const { return invocation_model_; } + + ProgramInvocation::Ptr CreateInvocation(); + + std::string to_s() const; + + operator iree_vm_context_t *() { return vm_context_.get(); } + operator iree_vm_function_t &() { return vm_function_; } + + private: + ProgramFunction(std::shared_ptr scope, iree::vm_context_ptr vm_context, + iree_vm_function_t vm_function, + std::optional invocation_model = {}); + + static ProgramInvocationModel GetInvocationModelFromFunction( + iree_vm_function_t &f); + + // The context that this function was resolved against. + std::shared_ptr scope_; + iree::vm_context_ptr vm_context_; + iree_vm_function_t vm_function_; + ProgramInvocationModel invocation_model_; + friend class Program; +}; + // High level API for working with program modules. Think of a module as // a shared library in a traditional Unix system: // @@ -36,13 +226,16 @@ class SHORTFIN_API System; class SHORTFIN_API ProgramModule { public: std::string to_s() const; - iree_vm_module_t* vm_module() const { return vm_module_; } + iree_vm_module_t *vm_module() const { return vm_module_; } std::string_view name() const; // Loads a dynamic bytecode module (VMFB) from a path on the file system. - static ProgramModule Load(System& system, const std::filesystem::path& path, + static ProgramModule Load(System &system, const std::filesystem::path &path, bool mmap = true); + // Gets the name of all exported functions. + std::vector exports() const; + protected: explicit ProgramModule(iree::vm_module_ptr vm_module) : vm_module_(std::move(vm_module)) {} @@ -52,28 +245,48 @@ class SHORTFIN_API ProgramModule { }; // Programs consist of ProgramModules instantiated together and capable of -// having functions invoked on them. While it is possible to construct -// programs that do not depend on device-associated state, the dominant -// use case is for programs that are compiled to operate against the device -// HAL with a list of concrete devices. Such programs are constructed from -// a Scope. +// having functions invoked on them. While the underlying programming model +// is a bit broader and can be exploited in various advanced way, generally, +// a program should be thought of as a fiber, and it is therefore bound to +// a Scope, which provides a logical thread of execution. By default, all +// invocations will take place in logical order (there are certain ways to +// violate this constraint safely that are provided for separately). // -// While the concurrency model for programs is technically a bit broader, the -// intended use is for them to be interacted with on a single Worker in a -// non-blocking fashion. There are many advanced ways that programs can be -// constructed to straddle devices, scopes, and workers, but that is left as -// an advanced use case. +// The program will source any needed parameters from the System and it will +// make an effort to cache them for proper locality on individual devices +// (TODO: make this actually true). class SHORTFIN_API Program { public: struct Options { + Options() {} + // Enables program-wide execution tracing (to stderr). bool trace_execution = false; }; + // Loads a program attached to a scope with a list of user provided modules + // and options. 
+ static Program Load(std::shared_ptr scope, + std::span modules, + Options options = {}); + + // Looks up a public function by fully qualified name (i.e. module.function). + // Returns nothing if not found. + std::optional LookupFunction(std::string_view name); + + // Looks up a public function by fully qualified name, throwing an + // invalid_argument exception on failure to find. + ProgramFunction LookupRequiredFunction(std::string_view name); + + // Gets the name of all exported functions. + std::vector exports() const; + private: - explicit Program(iree::vm_context_ptr context) - : context_(std::move(context)) {} - iree::vm_context_ptr context_; + explicit Program(std::shared_ptr scope, + iree::vm_context_ptr vm_context) + : scope_(std::move(scope)), vm_context_(std::move(vm_context)) {} + std::shared_ptr scope_; + iree::vm_context_ptr vm_context_; friend class Scope; }; diff --git a/libshortfin/src/shortfin/local/program_interfaces.h b/libshortfin/src/shortfin/local/program_interfaces.h new file mode 100644 index 000000000..280b77506 --- /dev/null +++ b/libshortfin/src/shortfin/local/program_interfaces.h @@ -0,0 +1,85 @@ +// Copyright 2024 Advanced Micro Devices, Inc +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Standalone interfaces needed for marshaling as part of a ProgramInvocation.h. +// They are available in this dep-free header in order to ease the burden on +// types that would otherwise need to pull in all of the includes. + +#ifndef SHORTFIN_LOCAL_PROGRAM_INTERFACES_H +#define SHORTFIN_LOCAL_PROGRAM_INTERFACES_H + +#include "shortfin/support/api.h" +#include "shortfin/support/iree_helpers.h" + +namespace shortfin::local { + +class SHORTFIN_API ProgramInvocation; + +// The type of barrier that should be managed for a program resource. +enum class ProgramResourceBarrier { + // The caller has explicitly not stated a preference. + DEFAULT, + + // The argument will be used by the program for input and the program + // must not perform operations on it until all pending mutations have + // been completed. Concurrent reads/uses are permitted. + // This is the default concurrency in most situations. + READ, + + // The argument will be used for input/output and the program must not + // perform operations on it until all prior mutations and uses have been + // complete. + WRITE, + + // No concurrency barriers will be emplaced on behalf of the argument, + // explicitly allowing racy access. The program and the caller must + // ensure that only valid accesses are made. + NONE, +}; + +// Implemented by a class if it can marshal itself to an invocation as an +// argument. +class SHORTFIN_API ProgramInvocationMarshalable { + public: + // Adds this object as an invocation argument. + virtual void AddAsInvocationArgument(ProgramInvocation *inv, + ProgramResourceBarrier barrier) = 0; +}; + +// Trampoline class that has visibility into marshalable types and can be used +// to construct them from an invocation reference. +class SHORTFIN_API ProgramInvocationMarshalableFactory { + public: + // Instantiates a new `T` from an opaque reference retrieved from an + // invocation result. This will call through to a factory on the type to + // construct a new user-value and setup any needed barriers from the + // invocation. 
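A user-facing type opts into this marshaling by implementing ProgramInvocationMarshalable and exposing the two static hooks that ProgramInvocationMarshalableFactory forwards to. A hedged sketch of the shape such a type takes; MyTensor and its body are hypothetical, and in this patch the real implementer is device_array:

class MyTensor : public shortfin::local::ProgramInvocationMarshalable {
 public:
  // Invoked when this value is added as an argument: push the backing VM ref
  // onto the invocation and register barriers / device selection as needed.
  void AddAsInvocationArgument(
      shortfin::local::ProgramInvocation *inv,
      shortfin::local::ProgramResourceBarrier barrier) override {
    // inv->AddArg(...); inv->DeviceSelect(...);
  }

 private:
  // Hooks used by ProgramInvocationMarshalableFactory for results.
  static MyTensor CreateFromInvocationResultRef(
      shortfin::local::ProgramInvocation *inv,
      shortfin::iree::vm_opaque_ref ref);
  static iree_vm_ref_type_t invocation_marshalable_type();
  friend class shortfin::local::ProgramInvocationMarshalableFactory;
};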
+ // + // In order for a type to be eligible for such usage, it must expose a + // `T CreateFromInvocationResultRef(ProgramInvocation *inv, + // iree::vm_opaque_ref)` static method. The type `T` must be friends with this + // class. + template + static T CreateFromInvocationResultRef(ProgramInvocation *inv, + iree::vm_opaque_ref ref) { + return T::CreateFromInvocationResultRef(inv, std::move(ref)); + } + + // Gets the type id that corresponds to this marshalable type. + // + // Marshalable types should define the same method. + // It is recommended that these type methods are defined in shortfin + // implementation files (not headers) since that ensures that no cross-DSO + // symbol visibility issues can transpire. + template + static iree_vm_ref_type_t invocation_marshalable_type() { + return T::invocation_marshalable_type(); + } +}; + +} // namespace shortfin::local + +#endif // SHORTFIN_LOCAL_PROGRAM_INTERFACES_H diff --git a/libshortfin/src/shortfin/local/scheduler.cc b/libshortfin/src/shortfin/local/scheduler.cc index c5a9fc062..f0200ebdf 100644 --- a/libshortfin/src/shortfin/local/scheduler.cc +++ b/libshortfin/src/shortfin/local/scheduler.cc @@ -12,6 +12,26 @@ namespace shortfin::local::detail { +namespace { + +[[maybe_unused]] std::string SummarizeFence(iree_hal_fence_t *fence) { + if (!SHORTFIN_SCHED_LOG_ENABLED) { + return std::string(); + } + std::string result("fence("); + iree_hal_semaphore_list_t list = iree_hal_fence_semaphore_list(fence); + for (iree_host_size_t i = 0; i < list.count; ++i) { + if (i > 0) result.append(", "); + result.append(fmt::format("[{}@{}]", + static_cast(list.semaphores[i]), + list.payload_values[i])); + } + result.append(")"); + return result; +} + +} // namespace + // -------------------------------------------------------------------------- // // Account // -------------------------------------------------------------------------- // @@ -30,9 +50,6 @@ void Account::Initialize() { void Account::Reset() { active_tx_type_ = TransactionType::NONE; - // if (active_command_buffer_) { - // iree_hal_command_buffer_end(active_command_buffer_); - // } active_command_buffer_.reset(); } @@ -55,11 +72,15 @@ CompletionEvent Account::OnSync() { iree::shared_event::ref satisfied(false); iree::hal_semaphore_ptr sem = sem_; auto idle_timepoint = idle_timepoint_; + SHORTFIN_SCHED_LOG("OnSync::Wait({}@{})", static_cast(sem.get()), + idle_timepoint); scheduler_.system().blocking_executor().Schedule( [sem = std::move(sem), idle_timepoint, satisfied]() { iree_status_t status = iree_hal_semaphore_wait( sem, idle_timepoint, iree_infinite_timeout()); IREE_CHECK_OK(status); + SHORTFIN_SCHED_LOG("OnSync::Complete({}@{})", + static_cast(sem.get()), idle_timepoint); satisfied->set(); }); return CompletionEvent(satisfied); @@ -89,6 +110,10 @@ void TimelineResource::use_barrier_insert(iree_hal_semaphore_t *sem, iree_hal_fence_insert(use_barrier_fence_, sem, timepoint)); } +iree_allocator_t TimelineResource::host_allocator() { + return scope_->host_allocator(); +} + // -------------------------------------------------------------------------- // // Scheduler // -------------------------------------------------------------------------- // @@ -140,8 +165,10 @@ void Scheduler::AppendCommandBuffer(ScopedDevice &device, TransactionType tx_type, std::function callback) { Account &account = GetDefaultAccount(device); - auto needed_affinity_bits = device.affinity().queue_affinity(); + SHORTFIN_SCHED_LOG( + "AppendCommandBuffer(account=0x{:x}, tx_type={}, queue_affinity={}):", + 
account.id(), static_cast(tx_type), needed_affinity_bits); // Initialize a fresh command buffer if needed. if (!account.active_command_buffer_) { @@ -181,6 +208,11 @@ void Scheduler::AppendCommandBuffer(ScopedDevice &device, account.active_deps_ = std::move(new_active_deps); account.active_command_buffer_ = std::move(new_cb); account.idle_timepoint_ += 1; + SHORTFIN_SCHED_LOG( + " : New command buffer (category={}, idle_timepoint={})", category, + account.idle_timepoint_); + } else { + SHORTFIN_SCHED_LOG(" : Continue active command buffer"); } // Perform the mutation. @@ -192,21 +224,29 @@ void Scheduler::AppendCommandBuffer(ScopedDevice &device, } } -void Scheduler::Flush() { +iree_status_t Scheduler::FlushWithStatus() noexcept { // This loop is optimized for a small number of accounts, where it is // fine to just linearly probe. If this ever becomes cumbersome, we can // maintain a dirty list which is appended to when an account transitions // from idle to active. for (Account &account : accounts_) { if (!account.active_command_buffer_) continue; - iree_hal_semaphore_t *signal_sem = account.sem_; uint64_t signal_timepoint = account.idle_timepoint_; iree_hal_command_buffer_t *active_command_buffer = account.active_command_buffer_; iree_hal_buffer_binding_table_t binding_tables = iree_hal_buffer_binding_table_empty(); - SHORTFIN_THROW_IF_ERROR(iree_hal_device_queue_execute( + + SHORTFIN_SCHED_LOG( + "Flush command buffer (account=0x{:x}, queue_affinity={}, " + "signal_timepoint={}, deps={})", + account.id(), account.active_queue_affinity_bits_, signal_timepoint, + SummarizeFence(account.active_deps_)); + + // End recording and submit. + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_end(active_command_buffer)); + IREE_RETURN_IF_ERROR(iree_hal_device_queue_execute( account.hal_device(), /*queue_affinity=*/account.active_queue_affinity_bits_, /*wait_sempahore_list=*/account.active_deps_ @@ -223,6 +263,14 @@ void Scheduler::Flush() { /*binding_tables=*/&binding_tables)); account.Reset(); } + return iree_ok_status(); +} + +iree::hal_fence_ptr Scheduler::NewFence() { + iree::hal_fence_ptr fence; + iree_hal_fence_create(semaphore_count_, system_.host_allocator(), + fence.for_output()); + return fence; } } // namespace shortfin::local::detail diff --git a/libshortfin/src/shortfin/local/scheduler.h b/libshortfin/src/shortfin/local/scheduler.h index 2f606ced3..680481e6b 100644 --- a/libshortfin/src/shortfin/local/scheduler.h +++ b/libshortfin/src/shortfin/local/scheduler.h @@ -142,6 +142,8 @@ class SHORTFIN_API TimelineResource { return iree_hal_fence_semaphore_list(use_barrier_fence_); } + iree_allocator_t host_allocator(); + private: TimelineResource(std::shared_ptr scope, size_t semaphore_capacity); ~TimelineResource(); @@ -174,7 +176,11 @@ class SHORTFIN_API Account { Account(Scheduler &scheduler, Device *device); Device *device() const { return device_; } iree_hal_device_t *hal_device() { return hal_device_; } + size_t semaphore_count() const { return 1; } + // Gets a unique integer id for this account. Currently just the address of + // the sem, but can be derived from any owned entity. + uintptr_t id() const { return reinterpret_cast(sem_.get()); } // Accesses the active command buffer. This will only be non-null if a // pending transaction has been set up (i.e. via AppendCommandBuffer). @@ -188,6 +194,7 @@ class SHORTFIN_API Account { // Queue timeline. 
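Each account owns one timeline semaphore: timeline_idle_timepoint() is the value at which everything queued so far completes, and timeline_acquire_timepoint() reserves the next value for new work to signal. The coarse-fences path in program.cc chains onto it roughly like the following sketch; `account`, `wait_fence`, and `host_allocator` are assumed to be in scope:

// Wait for all work currently queued on this account to finish ...
uint64_t idle = account.timeline_idle_timepoint();
IREE_CHECK_OK(iree_hal_fence_insert(wait_fence, account.timeline_sem(), idle));

// ... and reserve the next timepoint for this invocation to signal.
uint64_t acquired = account.timeline_acquire_timepoint();  // == idle + 1 here
iree_hal_fence_t *signal_fence = nullptr;
IREE_CHECK_OK(iree_hal_fence_create_at(account.timeline_sem(), acquired,
                                       host_allocator, &signal_fence));
// ... signal_fence is passed as the trailing signal argument ...
iree_hal_fence_release(signal_fence);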
iree_hal_semaphore_t *timeline_sem() { return sem_; } uint64_t timeline_idle_timepoint() { return idle_timepoint_; } + uint64_t timeline_acquire_timepoint() { return ++idle_timepoint_; } // Returns a future that is satisfied when the timeline of this account // reaches its current idle timepoint (i.e. all currently pending work @@ -248,7 +255,8 @@ class SHORTFIN_API Scheduler { std::function callback); // Flushes any pending accounts that have accumulated commands. - void Flush(); + iree_status_t FlushWithStatus() noexcept; + void Flush() { SHORTFIN_THROW_IF_ERROR(FlushWithStatus()); } // Gets a fresh TimelineResource which can be used for tracking resource // read/write and setting barriers. Note that these are all allocated fresh @@ -258,6 +266,10 @@ class SHORTFIN_API Scheduler { new TimelineResource(std::move(scope), semaphore_count_)); } + // Creates a new fence with capacity for all semaphores that are extant at + // the point of the call. + iree::hal_fence_ptr NewFence(); + System &system() { return system_; } private: diff --git a/libshortfin/src/shortfin/local/scope.cc b/libshortfin/src/shortfin/local/scope.cc index 39784f196..30211a052 100644 --- a/libshortfin/src/shortfin/local/scope.cc +++ b/libshortfin/src/shortfin/local/scope.cc @@ -9,7 +9,6 @@ #include #include -#include "iree/modules/hal/module.h" #include "shortfin/local/system.h" #include "shortfin/support/logging.h" @@ -90,48 +89,6 @@ std::vector Scope::device_names() const { return names; } -Program Scope::LoadUnboundProgram(std::span modules, - Program::Options options) { - std::vector all_modules; - std::vector raw_devices; - - // By default, bind all devices in the scope in order to the program. - for (Device *d : devices_) { - raw_devices.push_back(d->hal_device()); - } - - // Add a HAL module. - // TODO: at some point may want to change this to something similar to - // what the tooling does in iree_tooling_resolve_modules - it uses - // iree_vm_module_enumerate_dependencies to walk the dependencies and add the - // required modules only as needed. to start you could use it just to see if - // the hal is used, but as you add other module types for exposing sharkfin - // functionality (or module versions; iree_vm_module_dependency_t has the - // minimum version required so you can switch between them, and whether they - // are optional/required). - iree::vm_module_ptr hal_module; - SHORTFIN_THROW_IF_ERROR(iree_hal_module_create( - system().vm_instance(), raw_devices.size(), raw_devices.data(), - IREE_HAL_MODULE_FLAG_NONE, system().host_allocator(), - hal_module.for_output())); - all_modules.push_back(hal_module); - - // Add explicit modules. - for (auto &pm : modules) { - all_modules.push_back(pm.vm_module()); - } - - // Create the context. 
- iree::vm_context_ptr context; - iree_vm_context_flags_t flags = IREE_VM_CONTEXT_FLAG_CONCURRENT; - if (options.trace_execution) flags |= IREE_VM_CONTEXT_FLAG_TRACE_EXECUTION; - SHORTFIN_THROW_IF_ERROR(iree_vm_context_create_with_modules( - system().vm_instance(), flags, all_modules.size(), all_modules.data(), - system().host_allocator(), context.for_output())); - - return Program(std::move(context)); -} - // -------------------------------------------------------------------------- // // ScopedDevice // -------------------------------------------------------------------------- // diff --git a/libshortfin/src/shortfin/local/scope.h b/libshortfin/src/shortfin/local/scope.h index 0cb566b89..e02984f38 100644 --- a/libshortfin/src/shortfin/local/scope.h +++ b/libshortfin/src/shortfin/local/scope.h @@ -132,14 +132,6 @@ class SHORTFIN_API Scope : public std::enable_shared_from_this { return scheduler().NewTimelineResource(shared_ptr()); } - // Loads a program from a list of modules onto the devices managed by this - // scope. The resulting program is not bound to this scope and can be imported - // into compatible scopes for actual execution. - // TODO: This is temporary during API evolution: a higher level API that - // includes all module concepts, params, etc is needed. - Program LoadUnboundProgram(std::span modules, - Program::Options options = {}); - private: void AddDevice(std::string_view device_class, Device *device); void Initialize(); // Called after all devices are added. diff --git a/libshortfin/src/shortfin/local/worker.h b/libshortfin/src/shortfin/local/worker.h index 52f5e5948..10924ae3c 100644 --- a/libshortfin/src/shortfin/local/worker.h +++ b/libshortfin/src/shortfin/local/worker.h @@ -79,6 +79,7 @@ class SHORTFIN_API Worker { const Options &options() const { return options_; } const std::string_view name() const { return options_.name; } + iree_loop_t loop() { return loop_; } std::string to_s(); // Gets the Worker that is active for the current thread or nullptr if none. diff --git a/libshortfin/src/shortfin/support/iree_concurrency.h b/libshortfin/src/shortfin/support/iree_concurrency.h index 28ef1e99b..be3e42742 100644 --- a/libshortfin/src/shortfin/support/iree_concurrency.h +++ b/libshortfin/src/shortfin/support/iree_concurrency.h @@ -16,21 +16,7 @@ namespace shortfin::iree { -namespace detail { -struct thread_ptr_helper { - static void steal(iree_thread_t *obj) { LogIREESteal("iree_thread_t", obj); } - static void retain(iree_thread_t *obj) { - LogIREERetain("iree_thread_t", obj); - iree_thread_retain(obj); - } - static void release(iree_thread_t *obj) { - LogIREERelease("iree_thread_t", obj); - iree_thread_release(obj); - } -}; -}; // namespace detail - -using thread_ptr = object_ptr; +SHORTFIN_IREE_DEF_PTR(thread); // Wraps an iree::slim_mutex as an RAII object. 
class slim_mutex { diff --git a/libshortfin/src/shortfin/support/iree_helpers.h b/libshortfin/src/shortfin/support/iree_helpers.h index c77ddbaa8..3eee51b96 100644 --- a/libshortfin/src/shortfin/support/iree_helpers.h +++ b/libshortfin/src/shortfin/support/iree_helpers.h @@ -15,6 +15,7 @@ #include "iree/hal/api.h" #include "iree/modules/hal/types.h" #include "iree/vm/api.h" +#include "iree/vm/ref_cc.h" #include "shortfin/support/api.h" #if !defined(SHORTFIN_IREE_LOG_RC) @@ -32,6 +33,10 @@ inline std::string_view to_string_view(iree_string_view_t isv) { return std::string_view(isv.data, isv.size); } +inline iree_string_view_t to_iree_string_view(std::string_view sv) { + return iree_make_string_view(sv.data(), sv.size()); +} + namespace iree { // -------------------------------------------------------------------------- // @@ -52,132 +57,6 @@ inline void LogIREESteal(const char *type_name, void *ptr) {} inline void LogLiveRefs() {} #endif -struct hal_buffer_ptr_helper { - static void steal(iree_hal_buffer_t *obj) { - LogIREESteal("iree_hal_buffer_t", obj); - } - static void retain(iree_hal_buffer_t *obj) { - LogIREERetain("iree_hal_buffer_t", obj); - iree_hal_buffer_retain(obj); - } - static void release(iree_hal_buffer_t *obj) { - LogIREERelease("iree_hal_buffer_t", obj); - iree_hal_buffer_release(obj); - } -}; - -struct hal_command_buffer_helper { - static void steal(iree_hal_command_buffer_t *obj) { - LogIREESteal("iree_hal_command_buffer_t", obj); - } - static void retain(iree_hal_command_buffer_t *obj) { - LogIREERetain("iree_hal_command_buffer_t", obj); - iree_hal_command_buffer_retain(obj); - } - static void release(iree_hal_command_buffer_t *obj) { - LogIREERelease("iree_hal_command_buffer_t", obj); - iree_hal_command_buffer_release(obj); - } -}; - -struct hal_device_ptr_helper { - static void steal(iree_hal_device_t *obj) { - LogIREESteal("iree_hal_device_t", obj); - } - static void retain(iree_hal_device_t *obj) { - LogIREERetain("iree_hal_device_t", obj); - iree_hal_device_retain(obj); - } - static void release(iree_hal_device_t *obj) { - LogIREERelease("iree_hal_device_t", obj); - iree_hal_device_release(obj); - } -}; - -struct hal_driver_ptr_helper { - static void steal(iree_hal_driver_t *obj) { - LogIREESteal("iree_hal_driver_t", obj); - } - static void retain(iree_hal_driver_t *obj) { - LogIREERetain("iree_hal_driver_t", obj); - iree_hal_driver_retain(obj); - } - static void release(iree_hal_driver_t *obj) { - LogIREERelease("iree_hal_driver_t", obj); - iree_hal_driver_release(obj); - } -}; - -struct hal_fence_ptr_helper { - static void steal(iree_hal_fence_t *obj) { - LogIREESteal("iree_hal_fence_t", obj); - } - static void retain(iree_hal_fence_t *obj) { - LogIREERetain("iree_hal_fence_t", obj); - iree_hal_fence_retain(obj); - } - static void release(iree_hal_fence_t *obj) { - LogIREERelease("iree_hal_fence_t", obj); - iree_hal_fence_release(obj); - } -}; - -struct hal_semaphore_ptr_helper { - static void steal(iree_hal_semaphore_t *obj) { - LogIREESteal("iree_hal_semaphore_t", obj); - } - static void retain(iree_hal_semaphore_t *obj) { - LogIREERetain("iree_hal_semaphore_t", obj); - iree_hal_semaphore_retain(obj); - } - static void release(iree_hal_semaphore_t *obj) { - LogIREERelease("iree_hal_semaphore_t", obj); - iree_hal_semaphore_release(obj); - } -}; - -struct vm_context_ptr_helper { - static void steal(iree_vm_context_t *obj) { - LogIREESteal("iree_vm_context_t", obj); - } - static void retain(iree_vm_context_t *obj) { - LogIREERetain("iree_vm_context_t", obj); - 
iree_vm_context_retain(obj); - } - static void release(iree_vm_context_t *obj) { - LogIREERelease("iree_vm_context_t", obj); - iree_vm_context_release(obj); - } -}; - -struct vm_instance_ptr_helper { - static void steal(iree_vm_instance_t *obj) { - LogIREESteal("iree_vm_instance_t", obj); - } - static void retain(iree_vm_instance_t *obj) { - LogIREERetain("iree_vm_instance_t", obj); - iree_vm_instance_retain(obj); - } - static void release(iree_vm_instance_t *obj) { - LogIREERelease("iree_vm_instance_t", obj); - iree_vm_instance_release(obj); - } -}; - -struct vm_module_ptr_helper { - static void steal(iree_vm_module_t *obj) { - LogIREESteal("iree_vm_module_t", obj); - } - static void retain(iree_vm_module_t *obj) { - LogIREERetain("iree_vm_module_t", obj); - iree_vm_module_retain(obj); - } - static void release(iree_vm_module_t *obj) { - LogIREERelease("iree_vm_module_t", obj); - iree_vm_module_release(obj); - } -}; - }; // namespace detail // Wraps an IREE retain/release style object pointer in a smart-pointer @@ -261,24 +140,39 @@ class object_ptr { friend class Assignment; }; -using hal_buffer_ptr = - object_ptr; -using hal_command_buffer_ptr = - object_ptr; -using hal_driver_ptr = - object_ptr; -using hal_device_ptr = - object_ptr; -using hal_fence_ptr = - object_ptr; -using hal_semaphore_ptr = - object_ptr; -using vm_context_ptr = - object_ptr; -using vm_instance_ptr = - object_ptr; -using vm_module_ptr = - object_ptr; +// Defines a reference counting helper struct named like +// iree_hal_buffer_ptr_helper (for type_stem == hal_buffer). +// These must be defined in the shortfin::iree::detail namespace. +#define SHORTFIN_IREE_DEF_PTR(type_stem) \ + namespace detail { \ + struct type_stem##_ptr_helper { \ + static void steal(iree_##type_stem##_t *obj) { \ + LogIREESteal(#type_stem "_t", obj); \ + } \ + static void retain(iree_##type_stem##_t *obj) { \ + LogIREERetain(#type_stem "_t", obj); \ + iree_##type_stem##_retain(obj); \ + } \ + static void release(iree_##type_stem##_t *obj) { \ + LogIREERelease(#type_stem "_t", obj); \ + iree_##type_stem##_release(obj); \ + } \ + }; \ + } \ + using type_stem##_ptr = \ + object_ptr + +SHORTFIN_IREE_DEF_PTR(hal_command_buffer); +SHORTFIN_IREE_DEF_PTR(hal_buffer); +SHORTFIN_IREE_DEF_PTR(hal_buffer_view); +SHORTFIN_IREE_DEF_PTR(hal_device); +SHORTFIN_IREE_DEF_PTR(hal_driver); +SHORTFIN_IREE_DEF_PTR(hal_fence); +SHORTFIN_IREE_DEF_PTR(hal_semaphore); +SHORTFIN_IREE_DEF_PTR(vm_context); +SHORTFIN_IREE_DEF_PTR(vm_instance); +SHORTFIN_IREE_DEF_PTR(vm_list); +SHORTFIN_IREE_DEF_PTR(vm_module); // Holds a pointer allocated by some allocator, deleting it if still owned // at destruction time. @@ -432,13 +326,27 @@ class ignorable_status { ignorable_status(ignorable_status &&other) = delete; ~ignorable_status() { iree_status_ignore(status_); } - operator iree_status_t() const { return status_; } + // Consumes that status. Only the first consumer will receive all payloads. + // Others will just get the cloned basic status. 
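SHORTFIN_IREE_DEF_PTR replaces the removed hand-written helper structs; each instantiation generates a helper plus a smart-pointer alias. Roughly, SHORTFIN_IREE_DEF_PTR(hal_fence) expands to the following (reconstructed from the macro; the object_ptr template arguments follow the pre-existing aliases):

namespace detail {
struct hal_fence_ptr_helper {
  static void steal(iree_hal_fence_t *obj) { LogIREESteal("hal_fence_t", obj); }
  static void retain(iree_hal_fence_t *obj) {
    LogIREERetain("hal_fence_t", obj);
    iree_hal_fence_retain(obj);
  }
  static void release(iree_hal_fence_t *obj) {
    LogIREERelease("hal_fence_t", obj);
    iree_hal_fence_release(obj);
  }
};
}  // namespace detail
using hal_fence_ptr = object_ptr<iree_hal_fence_t, detail::hal_fence_ptr_helper>;

// Usage matches the previous hand-written aliases, e.g. in Scheduler::NewFence:
//   hal_fence_ptr fence;
//   iree_hal_fence_create(capacity, host_allocator, fence.for_output());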
+ iree_status_t ConsumeStatus() { + iree_status_t local_status = status_; + status_ = iree_status_clone(status_); + return local_status; + } iree_status_t status() const { return status_; } private: - iree_status_t status_; + mutable iree_status_t status_; }; +// -------------------------------------------------------------------------- // +// VM Ref and Variant Interop +// -------------------------------------------------------------------------- // + +using vm_opaque_ref = ::iree::vm::opaque_ref; +template +using vm_ref = ::iree::vm::ref; + } // namespace iree } // namespace shortfin diff --git a/libshortfin/src/shortfin/support/logging.h b/libshortfin/src/shortfin/support/logging.h index 337ebacae..4929d49fc 100644 --- a/libshortfin/src/shortfin/support/logging.h +++ b/libshortfin/src/shortfin/support/logging.h @@ -13,6 +13,14 @@ #define SHORTFIN_LOG_LIFETIMES 0 #endif +// Scheduler logging. +#define SHORTFIN_SCHED_LOG_ENABLED 0 +#if SHORTFIN_SCHED_LOG_ENABLED +#define SHORTFIN_SCHED_LOG(...) shortfin::logging::info("SCHED: " __VA_ARGS__) +#else +#define SHORTFIN_SCHED_LOG(...) +#endif + namespace shortfin::logging { // TODO: Re-export doesn't really work like this. Need to define API
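The scheduler trace added throughout this patch is compile-time gated: with SHORTFIN_SCHED_LOG_ENABLED set to 1, each call site becomes a logging::info with a "SCHED: " prefix (string-literal concatenation with the format string), and at 0 both the call and the SummarizeFence work it guards disappear. For illustration (the arguments shown are hypothetical):

// Call site, as used in scheduler.cc / program.cc:
SHORTFIN_SCHED_LOG("Flush command buffer (account=0x{:x}, signal_timepoint={})",
                   account_id, signal_timepoint);

// SHORTFIN_SCHED_LOG_ENABLED == 1 expands this to:
//   shortfin::logging::info(
//       "SCHED: Flush command buffer (account=0x{:x}, signal_timepoint={})",
//       account_id, signal_timepoint);
// SHORTFIN_SCHED_LOG_ENABLED == 0 expands it to nothing.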