Skip to content

Commit

Permalink
[shortfin] Finishes plumbing ProgramIsolation through to the Python A…
Browse files Browse the repository at this point in the history
…PI. (#350)

* Fixes an egregious bug on wait fence setup (that was causing
invocations to always have an empty wait set) discovered while testing
these more advanced modes.
* Adds `isolation=` to `Program` constructor and invocation APIs (for
fine-grained control).
* Adds `ProgramIsolation` enum.
* Adds `Program.isolation` property.
* Adds `ProgramFunction.isolation` property.
* Overhauls the mobilenet invocation test to exercise several newly
available concurrency variants.

Co-authored-by: Ean Garvey <[email protected]>
  • Loading branch information
stellaraccident and monorimet authored Oct 29, 2024
1 parent e465c83 commit 023d31f
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 30 deletions.
47 changes: 30 additions & 17 deletions shortfin/python/lib_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,10 @@ void PyAddProgramInvocationArg(py::capsule &inv_capsule, py::handle arg) {
py::cast<std::string>(py::repr(arg.type()))));
}

local::ProgramInvocation::Future PyFunctionCall(local::ProgramFunction &self,
py::args args,
local::Fiber &fiber) {
auto inv = self.CreateInvocation(fiber.shared_from_this());
local::ProgramInvocation::Future PyFunctionCall(
local::ProgramFunction &self, py::args args, local::Fiber &fiber,
std::optional<local::ProgramIsolation> isolation) {
auto inv = self.CreateInvocation(fiber.shared_from_this(), isolation);
py::capsule inv_capsule(inv.get());
for (py::handle arg : args) {
PyAddProgramInvocationArg(inv_capsule, arg);
Expand Down Expand Up @@ -446,6 +446,12 @@ void BindLocal(py::module_ &m) {
std::make_unique<PyWorkerExtension>(worker, interp_state, refs));
};

py::enum_<local::ProgramIsolation>(m, "ProgramIsolation")
.value("NONE", local::ProgramIsolation::NONE)
.value("PER_FIBER", local::ProgramIsolation::PER_FIBER)
.value("PER_CALL", local::ProgramIsolation::PER_CALL)
.export_values();

py::class_<local::SystemBuilder>(m, "SystemBuilder")
.def("create_system", [live_system_refs,
worker_initializer](local::SystemBuilder &self) {
Expand Down Expand Up @@ -592,17 +598,21 @@ void BindLocal(py::module_ &m) {
.def("__repr__", &local::DeviceAffinity::to_s);

py::class_<local::Program>(m, "Program")
.def(py::new_([](std::span<const local::ProgramModule> modules,
std::vector<const local::Device *> devices,
bool trace_execution) {
local::Program::Options options;
options.devices = devices;
options.trace_execution = trace_execution;
return local::Program::Load(modules, std::move(options));
}),
py::arg("modules"), py::kw_only(), py::arg("devices"),
py::arg("trace_execution") = false)
.def(
py::new_([](std::span<const local::ProgramModule> modules,
std::vector<const local::Device *> devices,
bool trace_execution, local::ProgramIsolation isolation) {
local::Program::Options options;
options.devices = devices;
options.trace_execution = trace_execution;
options.isolation = isolation;
return local::Program::Load(modules, std::move(options));
}),
py::arg("modules"), py::kw_only(), py::arg("devices"),
py::arg("trace_execution") = false,
py::arg("isolation") = local::ProgramIsolation::PER_FIBER)
.def_prop_ro("exports", &local::Program::exports)
.def_prop_ro("isolation", &local::Program::isolation)
.def("lookup_function", &local::Program::LookupRequiredFunction)
.def("__getitem__", &local::Program::LookupRequiredFunction);
py::class_<local::ProgramFunction>(m, "ProgramFunction")
Expand All @@ -611,12 +621,15 @@ void BindLocal(py::module_ &m) {
&local::ProgramFunction::calling_convention)
.def(
"invocation",
[](local::ProgramFunction &self, local::Fiber &fiber) {
return self.CreateInvocation(fiber.shared_from_this());
[](local::ProgramFunction &self, local::Fiber &fiber,
std::optional<local::ProgramIsolation> isolation) {
return self.CreateInvocation(fiber.shared_from_this(), isolation);
},
py::arg("fiber"), py::arg("isolation") = py::none(),
DOCSTRING_PROGRAM_FUNCTION_INVOCATION)
.def_prop_ro("isolation", &local::ProgramFunction::isolation)
.def("__call__", PyFunctionCall, py::arg("args"), py::kw_only(),
py::arg("fiber"))
py::arg("fiber"), py::arg("isolation") = py::none())
.def("__repr__", &local::ProgramFunction::to_s);
py::class_<local::ProgramModule>(m, "ProgramModule")
.def_prop_ro("exports", &local::ProgramModule::exports)
Expand Down
1 change: 1 addition & 0 deletions shortfin/python/shortfin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Process = _sfl.local.Process
Program = _sfl.local.Program
ProgramFunction = _sfl.local.ProgramFunction
ProgramIsolation = _sfl.local.ProgramIsolation
ProgramInvocation = _sfl.local.ProgramInvocation
ProgramInvocationFuture = _sfl.local.ProgramInvocationFuture
ProgramModule = _sfl.local.ProgramModule
Expand Down
18 changes: 10 additions & 8 deletions shortfin/src/shortfin/local/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,17 @@ std::string_view ProgramFunction::calling_convention() const {
}

ProgramInvocation::Ptr ProgramFunction::CreateInvocation(
std::shared_ptr<Fiber> fiber) {
std::shared_ptr<Fiber> fiber, std::optional<ProgramIsolation> isolation) {
ProgramIsolation actual_isolation = isolation ? *isolation : isolation_;
// Low-overhead NONE isolation handling (saves some ref-count twiddling).
if (isolation_ == ProgramIsolation::NONE) {
if (actual_isolation == ProgramIsolation::NONE) {
return ProgramInvocation::New(std::move(fiber), vm_context_, vm_function_,
invocation_model_, /*isolate=*/nullptr);
}

// Create an isolated invocation.
auto [isolated_context, isolate] =
detail::ProgramIsolate::AcquireIsolate(*fiber, vm_context_, isolation_);
auto [isolated_context, isolate] = detail::ProgramIsolate::AcquireIsolate(
*fiber, vm_context_, actual_isolation);
return ProgramInvocation::New(std::move(fiber), std::move(isolated_context),
vm_function_, invocation_model_, isolate);
}
Expand Down Expand Up @@ -403,26 +404,27 @@ iree_status_t ProgramInvocation::FinalizeCallingConvention(
// Handle post-processing invocation model setup.
if (invocation_model == ProgramInvocationModel::COARSE_FENCES) {
// If we have a device_selection, set up to signal the leader account.
iree_hal_fence_t *maybe_wait_fence = nullptr;
if (device_selection_) {
ScopedDevice scoped_device(*fiber(), device_selection_);
auto &sched_account =
fiber()->scheduler().GetDefaultAccount(scoped_device);
iree_hal_fence_t *wait_fence = this->wait_fence();
maybe_wait_fence = this->wait_fence();
iree_hal_semaphore_t *timeline_sem = sched_account.timeline_sem();
uint64_t timeline_now = sched_account.timeline_idle_timepoint();
SHORTFIN_SCHED_LOG("Invocation {}: Wait on account timeline {}@{}",
static_cast<void *>(this),
static_cast<void *>(timeline_sem), timeline_now);
IREE_RETURN_IF_ERROR(
iree_hal_fence_insert(wait_fence, timeline_sem, timeline_now));
iree_hal_fence_insert(maybe_wait_fence, timeline_sem, timeline_now));
signal_sem_ = sched_account.timeline_sem();
signal_timepoint_ = sched_account.timeline_acquire_timepoint();
}

// Push wait fence (or null if no wait needed).
::iree::vm::ref<iree_hal_fence_t> wait_ref;
if (wait_fence_) {
::iree::vm::retain_ref(wait_fence());
if (maybe_wait_fence) {
wait_ref = ::iree::vm::retain_ref(maybe_wait_fence);
}
IREE_RETURN_IF_ERROR(iree_vm_list_push_ref_move(arg_list, wait_ref));

Expand Down
10 changes: 8 additions & 2 deletions shortfin/src/shortfin/local/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,11 @@ class SHORTFIN_API ProgramFunction {
std::string_view name() const;
std::string_view calling_convention() const;
ProgramInvocationModel invocation_model() const { return invocation_model_; }

ProgramInvocation::Ptr CreateInvocation(std::shared_ptr<Fiber> fiber);
// Gets the default isolation level for this function.
ProgramIsolation isolation() const { return isolation_; }
ProgramInvocation::Ptr CreateInvocation(
std::shared_ptr<Fiber> fiber,
std::optional<ProgramIsolation> isolation = std::nullopt);

std::string to_s() const;

Expand Down Expand Up @@ -324,6 +327,9 @@ class SHORTFIN_API Program {
// Gets the name of all exported functions.
std::vector<std::string> exports() const;

// Gets the default isolation level for all functions in this program.
ProgramIsolation isolation() const { return isolation_; }

// Eagerly does any per-fiber isolation preparation for the program at a
// convenient point (usually init time) to avoid first-invocation overhead.
void PrepareIsolate(Fiber &fiber);
Expand Down
131 changes: 128 additions & 3 deletions shortfin/tests/invocation/mobilenet_program_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,18 @@ def mobilenet_program_function(
return main_function


@pytest.fixture
def mobilenet_program_function_per_call(
lsys, mobilenet_compiled_cpu_path
) -> tuple[sf.ProgramFunction]:
program_module = lsys.load_module(mobilenet_compiled_cpu_path)
program = sf.Program(
[program_module], devices=lsys.devices, isolation=sf.ProgramIsolation.PER_CALL
)
main_function = program["module.torch-jit-export"]
return main_function


def get_mobilenet_ref_input(device) -> sfnp.device_array:
dummy_data = array.array(
"f", ([0.2] * (224 * 224)) + ([0.4] * (224 * 224)) + ([-0.2] * (224 * 224))
Expand All @@ -62,11 +74,12 @@ async def assert_mobilenet_ref_output(device, device_output):
absmean = functools.reduce(
lambda x, y: x + abs(y) / len(flat_output), flat_output, 0.0
)
print("RESULT:", absmean)
assert absmean == pytest.approx(5.01964943873882)


def test_invoke_mobilenet(lsys, fiber0, mobilenet_program_function):
# Tests that a single invocation on a single fiber works.
def test_invoke_mobilenet_single_per_fiber(lsys, fiber0, mobilenet_program_function):
assert mobilenet_program_function.isolation == sf.ProgramIsolation.PER_FIBER
device = fiber0.device(0)

async def main():
Expand All @@ -77,7 +90,119 @@ async def main():
lsys.run(main())


def test_invoke_mobilenet_multi_fiber(lsys, mobilenet_program_function):
# Tests that a single invocation on a single fiber in per_call mode works.
def test_invoke_mobilenet_single_per_call(
lsys, fiber0, mobilenet_program_function_per_call
):
assert mobilenet_program_function_per_call.isolation == sf.ProgramIsolation.PER_CALL
device = fiber0.device(0)

async def main():
device_input = get_mobilenet_ref_input(device)
(device_output,) = await mobilenet_program_function_per_call(
device_input, fiber=fiber0
)
await assert_mobilenet_ref_output(device, device_output)

lsys.run(main())


# Tests that chained back to back invocations on the same fiber work correctly.
# Does an async gather/assert with all results at the end.
def test_invoke_mobilenet_chained_per_fiber(lsys, fiber0, mobilenet_program_function):
assert mobilenet_program_function.isolation == sf.ProgramIsolation.PER_FIBER
device = fiber0.device(0)

async def main():
device_input = get_mobilenet_ref_input(device)
results = [
await mobilenet_program_function(device_input, fiber=fiber0)
for _ in range(5)
]

await asyncio.gather(
*[
assert_mobilenet_ref_output(device, device_output)
for (device_output,) in results
]
)

lsys.run(main())


# Tests that parallel invocations on a single fiber with a program in PER_CALL
# isolation functions properly. Note that in this variant, the await is done
# on all invocations vs serially per invocation (as in
# test_invoke_mobilenet_chained_per_fiber). This would be illegal if done on the
# same fiber without PER_CALL isolation managing forks.
#
# Note that since these are all operating on the same fiber, they are added to
# the device-side work graph with a one-after-the-other dependency, but the
# host side schedules concurrently.
def test_invoke_mobilenet_parallel_per_call(
lsys, fiber0, mobilenet_program_function_per_call
):
assert mobilenet_program_function_per_call.isolation == sf.ProgramIsolation.PER_CALL
device = fiber0.device(0)

async def main():
device_input = get_mobilenet_ref_input(device)
results = await asyncio.gather(
*[
mobilenet_program_function_per_call(device_input, fiber=fiber0)
for _ in range(5)
]
)

await asyncio.gather(
*[
assert_mobilenet_ref_output(device, device_output)
for (device_output,) in results
]
)

lsys.run(main())


# Same as above but uses explicit isolation controls on the function vs as the
# program level. If this constraint were violated, shortfin makes a best effort
# attempt to detect the situation and raise an exception, but there are a subset
# of programs which are purely async and would make detection of this exception
# lossy in the synchronous completion case.
def test_invoke_mobilenet_parallel_per_call_explicit(
lsys, fiber0, mobilenet_program_function
):
assert mobilenet_program_function.isolation == sf.ProgramIsolation.PER_FIBER
device = fiber0.device(0)

async def main():
device_input = get_mobilenet_ref_input(device)
results = await asyncio.gather(
*[
mobilenet_program_function(
device_input, fiber=fiber0, isolation=sf.ProgramIsolation.PER_CALL
)
for _ in range(50)
]
)

await asyncio.gather(
*[
assert_mobilenet_ref_output(device, device_output)
for (device_output,) in results
]
)

lsys.run(main())


# Tests that independent executions on multiple fibers all run concurrently.
# All fibers share the same host thread but schedule concurrently. Since
# each fiber has its own timeline, device side graphs have no dependency on
# each other and also schedule concurrently.
def test_invoke_mobilenet_multi_fiber_per_fiber(lsys, mobilenet_program_function):
assert mobilenet_program_function.isolation == sf.ProgramIsolation.PER_FIBER

class InferProcess(sf.Process):
async def run(self):
start_time = time.time()
Expand Down

0 comments on commit 023d31f

Please sign in to comment.