From 023d31fccd1634752a9bcaa50cf6f2c2074d0441 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 29 Oct 2024 08:05:25 -0700 Subject: [PATCH] [shortfin] Finishes plumbing ProgramIsolation through to the Python API. (#350) * Fixes an egregious bug on wait fence setup (that was causing invocations to always have an empty wait set) discovered while testing these more advanced modes. * Adds `isolation=` to `Program` constructor and invocation APIs (for fine-grained control). * Adds `ProgramIsolation` enum. * Adds `Program.isolation` property. * Adds `ProgramFunction.isolation` property. * Overhauls the mobilenet invocation test to exercise several newly available concurrency variants. Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com> --- shortfin/python/lib_ext.cc | 47 ++++--- shortfin/python/shortfin/__init__.py | 1 + shortfin/src/shortfin/local/program.cc | 18 +-- shortfin/src/shortfin/local/program.h | 10 +- .../invocation/mobilenet_program_test.py | 131 +++++++++++++++++- 5 files changed, 177 insertions(+), 30 deletions(-) diff --git a/shortfin/python/lib_ext.cc b/shortfin/python/lib_ext.cc index c73bf5a93..f04e29bba 100644 --- a/shortfin/python/lib_ext.cc +++ b/shortfin/python/lib_ext.cc @@ -247,10 +247,10 @@ void PyAddProgramInvocationArg(py::capsule &inv_capsule, py::handle arg) { py::cast(py::repr(arg.type())))); } -local::ProgramInvocation::Future PyFunctionCall(local::ProgramFunction &self, - py::args args, - local::Fiber &fiber) { - auto inv = self.CreateInvocation(fiber.shared_from_this()); +local::ProgramInvocation::Future PyFunctionCall( + local::ProgramFunction &self, py::args args, local::Fiber &fiber, + std::optional isolation) { + auto inv = self.CreateInvocation(fiber.shared_from_this(), isolation); py::capsule inv_capsule(inv.get()); for (py::handle arg : args) { PyAddProgramInvocationArg(inv_capsule, arg); @@ -446,6 +446,12 @@ void BindLocal(py::module_ &m) { std::make_unique(worker, interp_state, refs)); }; + 
py::enum_(m, "ProgramIsolation") + .value("NONE", local::ProgramIsolation::NONE) + .value("PER_FIBER", local::ProgramIsolation::PER_FIBER) + .value("PER_CALL", local::ProgramIsolation::PER_CALL) + .export_values(); + py::class_(m, "SystemBuilder") .def("create_system", [live_system_refs, worker_initializer](local::SystemBuilder &self) { @@ -592,17 +598,21 @@ void BindLocal(py::module_ &m) { .def("__repr__", &local::DeviceAffinity::to_s); py::class_(m, "Program") - .def(py::new_([](std::span modules, - std::vector devices, - bool trace_execution) { - local::Program::Options options; - options.devices = devices; - options.trace_execution = trace_execution; - return local::Program::Load(modules, std::move(options)); - }), - py::arg("modules"), py::kw_only(), py::arg("devices"), - py::arg("trace_execution") = false) + .def( + py::new_([](std::span modules, + std::vector devices, + bool trace_execution, local::ProgramIsolation isolation) { + local::Program::Options options; + options.devices = devices; + options.trace_execution = trace_execution; + options.isolation = isolation; + return local::Program::Load(modules, std::move(options)); + }), + py::arg("modules"), py::kw_only(), py::arg("devices"), + py::arg("trace_execution") = false, + py::arg("isolation") = local::ProgramIsolation::PER_FIBER) .def_prop_ro("exports", &local::Program::exports) + .def_prop_ro("isolation", &local::Program::isolation) .def("lookup_function", &local::Program::LookupRequiredFunction) .def("__getitem__", &local::Program::LookupRequiredFunction); py::class_(m, "ProgramFunction") @@ -611,12 +621,15 @@ void BindLocal(py::module_ &m) { &local::ProgramFunction::calling_convention) .def( "invocation", - [](local::ProgramFunction &self, local::Fiber &fiber) { - return self.CreateInvocation(fiber.shared_from_this()); + [](local::ProgramFunction &self, local::Fiber &fiber, + std::optional isolation) { + return self.CreateInvocation(fiber.shared_from_this(), isolation); }, + py::arg("fiber"), 
py::arg("isolation") = py::none(), DOCSTRING_PROGRAM_FUNCTION_INVOCATION) + .def_prop_ro("isolation", &local::ProgramFunction::isolation) .def("__call__", PyFunctionCall, py::arg("args"), py::kw_only(), - py::arg("fiber")) + py::arg("fiber"), py::arg("isolation") = py::none()) .def("__repr__", &local::ProgramFunction::to_s); py::class_(m, "ProgramModule") .def_prop_ro("exports", &local::ProgramModule::exports) diff --git a/shortfin/python/shortfin/__init__.py b/shortfin/python/shortfin/__init__.py index e1448c9da..c91058d62 100644 --- a/shortfin/python/shortfin/__init__.py +++ b/shortfin/python/shortfin/__init__.py @@ -20,6 +20,7 @@ Process = _sfl.local.Process Program = _sfl.local.Program ProgramFunction = _sfl.local.ProgramFunction +ProgramIsolation = _sfl.local.ProgramIsolation ProgramInvocation = _sfl.local.ProgramInvocation ProgramInvocationFuture = _sfl.local.ProgramInvocationFuture ProgramModule = _sfl.local.ProgramModule diff --git a/shortfin/src/shortfin/local/program.cc b/shortfin/src/shortfin/local/program.cc index 038cd106a..2af4e7a52 100644 --- a/shortfin/src/shortfin/local/program.cc +++ b/shortfin/src/shortfin/local/program.cc @@ -74,16 +74,17 @@ std::string_view ProgramFunction::calling_convention() const { } ProgramInvocation::Ptr ProgramFunction::CreateInvocation( - std::shared_ptr fiber) { + std::shared_ptr fiber, std::optional isolation) { + ProgramIsolation actual_isolation = isolation ? *isolation : isolation_; // Low-overhead NONE isolation handling (saves some ref-count twiddling). - if (isolation_ == ProgramIsolation::NONE) { + if (actual_isolation == ProgramIsolation::NONE) { return ProgramInvocation::New(std::move(fiber), vm_context_, vm_function_, invocation_model_, /*isolate=*/nullptr); } // Create an isolated invocation. 
- auto [isolated_context, isolate] = - detail::ProgramIsolate::AcquireIsolate(*fiber, vm_context_, isolation_); + auto [isolated_context, isolate] = detail::ProgramIsolate::AcquireIsolate( + *fiber, vm_context_, actual_isolation); return ProgramInvocation::New(std::move(fiber), std::move(isolated_context), vm_function_, invocation_model_, isolate); } @@ -403,26 +404,27 @@ iree_status_t ProgramInvocation::FinalizeCallingConvention( // Handle post-processing invocation model setup. if (invocation_model == ProgramInvocationModel::COARSE_FENCES) { // If we have a device_selection, set up to signal the leader account. + iree_hal_fence_t *maybe_wait_fence = nullptr; if (device_selection_) { ScopedDevice scoped_device(*fiber(), device_selection_); auto &sched_account = fiber()->scheduler().GetDefaultAccount(scoped_device); - iree_hal_fence_t *wait_fence = this->wait_fence(); + maybe_wait_fence = this->wait_fence(); iree_hal_semaphore_t *timeline_sem = sched_account.timeline_sem(); uint64_t timeline_now = sched_account.timeline_idle_timepoint(); SHORTFIN_SCHED_LOG("Invocation {}: Wait on account timeline {}@{}", static_cast(this), static_cast(timeline_sem), timeline_now); IREE_RETURN_IF_ERROR( - iree_hal_fence_insert(wait_fence, timeline_sem, timeline_now)); + iree_hal_fence_insert(maybe_wait_fence, timeline_sem, timeline_now)); signal_sem_ = sched_account.timeline_sem(); signal_timepoint_ = sched_account.timeline_acquire_timepoint(); } // Push wait fence (or null if no wait needed). 
::iree::vm::ref wait_ref; - if (wait_fence_) { - ::iree::vm::retain_ref(wait_fence()); + if (maybe_wait_fence) { + wait_ref = ::iree::vm::retain_ref(maybe_wait_fence); } IREE_RETURN_IF_ERROR(iree_vm_list_push_ref_move(arg_list, wait_ref)); diff --git a/shortfin/src/shortfin/local/program.h b/shortfin/src/shortfin/local/program.h index ea4f0cc3f..450b29736 100644 --- a/shortfin/src/shortfin/local/program.h +++ b/shortfin/src/shortfin/local/program.h @@ -214,8 +214,11 @@ class SHORTFIN_API ProgramFunction { std::string_view name() const; std::string_view calling_convention() const; ProgramInvocationModel invocation_model() const { return invocation_model_; } - - ProgramInvocation::Ptr CreateInvocation(std::shared_ptr fiber); + // Gets the default isolation level for this function. + ProgramIsolation isolation() const { return isolation_; } + ProgramInvocation::Ptr CreateInvocation( + std::shared_ptr fiber, + std::optional isolation = std::nullopt); std::string to_s() const; @@ -324,6 +327,9 @@ class SHORTFIN_API Program { // Gets the name of all exported functions. std::vector exports() const; + // Gets the default isolation level for all functions in this program. + ProgramIsolation isolation() const { return isolation_; } + // Eagerly does any per-fiber isolation preparation for the program at a // convenient point (usually init time) to avoid first-invocation overhead. 
void PrepareIsolate(Fiber &fiber); diff --git a/shortfin/tests/invocation/mobilenet_program_test.py b/shortfin/tests/invocation/mobilenet_program_test.py index 84903fb8f..a0f209219 100644 --- a/shortfin/tests/invocation/mobilenet_program_test.py +++ b/shortfin/tests/invocation/mobilenet_program_test.py @@ -42,6 +42,18 @@ def mobilenet_program_function( return main_function +@pytest.fixture +def mobilenet_program_function_per_call( + lsys, mobilenet_compiled_cpu_path +) -> tuple[sf.ProgramFunction]: + program_module = lsys.load_module(mobilenet_compiled_cpu_path) + program = sf.Program( + [program_module], devices=lsys.devices, isolation=sf.ProgramIsolation.PER_CALL + ) + main_function = program["module.torch-jit-export"] + return main_function + + def get_mobilenet_ref_input(device) -> sfnp.device_array: dummy_data = array.array( "f", ([0.2] * (224 * 224)) + ([0.4] * (224 * 224)) + ([-0.2] * (224 * 224)) @@ -62,11 +74,12 @@ async def assert_mobilenet_ref_output(device, device_output): absmean = functools.reduce( lambda x, y: x + abs(y) / len(flat_output), flat_output, 0.0 ) - print("RESULT:", absmean) assert absmean == pytest.approx(5.01964943873882) -def test_invoke_mobilenet(lsys, fiber0, mobilenet_program_function): +# Tests that a single invocation on a single fiber works. +def test_invoke_mobilenet_single_per_fiber(lsys, fiber0, mobilenet_program_function): + assert mobilenet_program_function.isolation == sf.ProgramIsolation.PER_FIBER device = fiber0.device(0) async def main(): @@ -77,7 +90,119 @@ async def main(): lsys.run(main()) -def test_invoke_mobilenet_multi_fiber(lsys, mobilenet_program_function): +# Tests that a single invocation on a single fiber in per_call mode works. 
+def test_invoke_mobilenet_single_per_call( + lsys, fiber0, mobilenet_program_function_per_call +): + assert mobilenet_program_function_per_call.isolation == sf.ProgramIsolation.PER_CALL + device = fiber0.device(0) + + async def main(): + device_input = get_mobilenet_ref_input(device) + (device_output,) = await mobilenet_program_function_per_call( + device_input, fiber=fiber0 + ) + await assert_mobilenet_ref_output(device, device_output) + + lsys.run(main()) + + +# Tests that chained back-to-back invocations on the same fiber work correctly. +# Does an async gather/assert with all results at the end. +def test_invoke_mobilenet_chained_per_fiber(lsys, fiber0, mobilenet_program_function): + assert mobilenet_program_function.isolation == sf.ProgramIsolation.PER_FIBER + device = fiber0.device(0) + + async def main(): + device_input = get_mobilenet_ref_input(device) + results = [ + await mobilenet_program_function(device_input, fiber=fiber0) + for _ in range(5) + ] + + await asyncio.gather( + *[ + assert_mobilenet_ref_output(device, device_output) + for (device_output,) in results + ] + ) + + lsys.run(main()) + + +# Tests that parallel invocations on a single fiber with a program in PER_CALL +# isolation function properly. Note that in this variant, the await is done +# on all invocations vs serially per invocation (as in +# test_invoke_mobilenet_chained_per_fiber). This would be illegal if done on the +# same fiber without PER_CALL isolation managing forks. +# +# Note that since these are all operating on the same fiber, they are added to +# the device-side work graph with a one-after-the-other dependency, but the +# host side schedules concurrently. 
+def test_invoke_mobilenet_parallel_per_call( + lsys, fiber0, mobilenet_program_function_per_call +): + assert mobilenet_program_function_per_call.isolation == sf.ProgramIsolation.PER_CALL + device = fiber0.device(0) + + async def main(): + device_input = get_mobilenet_ref_input(device) + results = await asyncio.gather( + *[ + mobilenet_program_function_per_call(device_input, fiber=fiber0) + for _ in range(5) + ] + ) + + await asyncio.gather( + *[ + assert_mobilenet_ref_output(device, device_output) + for (device_output,) in results + ] + ) + + lsys.run(main()) + + +# Same as above but uses explicit isolation controls on the function vs at the +# program level. If this constraint were violated, shortfin makes a best effort +# attempt to detect the situation and raise an exception, but there is a subset +# of programs which are purely async and would make detection of this exception +# lossy in the synchronous completion case. +def test_invoke_mobilenet_parallel_per_call_explicit( + lsys, fiber0, mobilenet_program_function +): + assert mobilenet_program_function.isolation == sf.ProgramIsolation.PER_FIBER + device = fiber0.device(0) + + async def main(): + device_input = get_mobilenet_ref_input(device) + results = await asyncio.gather( + *[ + mobilenet_program_function( + device_input, fiber=fiber0, isolation=sf.ProgramIsolation.PER_CALL + ) + for _ in range(50) + ] + ) + + await asyncio.gather( + *[ + assert_mobilenet_ref_output(device, device_output) + for (device_output,) in results + ] + ) + + lsys.run(main()) + + +# Tests that independent executions on multiple fibers all run concurrently. +# All fibers share the same host thread but schedule concurrently. Since +# each fiber has its own timeline, device side graphs have no dependency on +# each other and also schedule concurrently. 
+def test_invoke_mobilenet_multi_fiber_per_fiber(lsys, mobilenet_program_function): + assert mobilenet_program_function.isolation == sf.ProgramIsolation.PER_FIBER + class InferProcess(sf.Process): async def run(self): start_time = time.time()