Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] [wip] Rework memory store signal checking in C++ instead of cython #49319

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,12 +1110,6 @@ def test_dag_exception_chained(ray_start_regular, capsys):
# Can use the DAG after exceptions are thrown.
assert ray.get(compiled_dag.execute(1)) == 2

# Note: somehow the auto triggered teardown() from ray.shutdown()
# does not finish in time for this test, leading to a segfault
# of the following test (likely due to a dangling monitor thread
# upon the new Ray init).
compiled_dag.teardown()


@pytest.mark.parametrize("single_fetch", [True, False])
def test_dag_exception_multi_output(ray_start_regular, single_fetch, capsys):
Expand Down
1 change: 0 additions & 1 deletion src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,6 @@ CoreWorker::CoreWorker(CoreWorkerOptions options, const WorkerID &worker_id)
io_service_,
reference_counter_.get(),
local_raylet_client_,
options_.check_signals,
[this](const RayObject &obj) {
rpc::ErrorType error_type;
if (obj.IsException(&error_type) &&
Expand Down
37 changes: 22 additions & 15 deletions src/ray/core_worker/store_provider/memory_store/memory_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include <condition_variable>
#include <csignal>
#include <utility>

#include "ray/common/ray_config.h"
Expand All @@ -31,6 +32,14 @@ const int64_t kUnhandledErrorGracePeriodNanos = static_cast<int64_t>(5e9);
// when there are too many local objects.
const int kMaxUnhandledErrorScanItems = 1000;

namespace {

/// Number of the most recent signal recorded by SignalHandler, or -1 if the
/// handler has not run yet.
std::atomic<int> signal_received{-1};

// A signal handler may only access atomic objects that are lock-free; touching
// a lock-based atomic from a handler is undefined behavior. Enforce that at
// compile time instead of silently assuming it holds on every platform.
static_assert(std::atomic<int>::is_always_lock_free,
              "signal_received is written from a signal handler and must be lock-free");

/// Signal handler that records the delivered signal number in signal_received.
///
/// The body is deliberately limited to a single lock-free atomic store so that
/// it stays async-signal-safe: no allocation, no locking, no logging.
///
/// \param signal The signal number passed in by the runtime.
void SignalHandler(int signal) { signal_received.store(signal); }

}  // namespace

/// A class that represents a `Get` request.
class GetRequest {
public:
Expand Down Expand Up @@ -153,14 +162,12 @@ CoreWorkerMemoryStore::CoreWorkerMemoryStore(
instrumented_io_context &io_context,
ReferenceCounter *counter,
std::shared_ptr<raylet::RayletClient> raylet_client,
std::function<Status()> check_signals,
std::function<void(const RayObject &)> unhandled_exception_handler,
std::function<std::shared_ptr<ray::RayObject>(
const ray::RayObject &object, const ObjectID &object_id)> object_allocator)
: io_context_(io_context),
ref_counter_(counter),
raylet_client_(std::move(raylet_client)),
check_signals_(std::move(check_signals)),
unhandled_exception_handler_(std::move(unhandled_exception_handler)),
object_allocator_(std::move(object_allocator)) {}

Expand Down Expand Up @@ -366,7 +373,6 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector<ObjectID> &object_ids,

bool done = false;
bool timed_out = false;
Status signal_status = Status::OK();
int64_t remaining_timeout = timeout_ms;
int64_t iteration_timeout =
std::min(timeout_ms, RayConfig::instance().get_timeout_milliseconds());
Expand All @@ -379,16 +385,16 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector<ObjectID> &object_ids,
// calls. If timeout_ms == -1, this should run forever until all objects are
// ready or a signal is received. Else it should run repeatedly until that timeout
// is reached.
while (!timed_out && signal_status.ok() &&
!(done = get_request->Wait(iteration_timeout))) {
if (check_signals_) {
signal_status = check_signals_();
}

if (remaining_timeout >= 0) {
remaining_timeout -= iteration_timeout;
iteration_timeout = std::min(remaining_timeout, iteration_timeout);
timed_out = remaining_timeout <= 0;
{
std::signal(SIGINT, SignalHandler);
std::signal(SIGTERM, SignalHandler);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure you want to overwrite the default SIGTERM handler? That means any exception thrown will not have any effect.

Copy link
Contributor Author

@dayshah dayshah Dec 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm trying to replicate the behavior of the check_signals callback mentioned in the PR description. I think SIGTERM maps to SystemExit; I will double-check.

Copy link
Contributor Author

@dayshah dayshah Dec 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, possibly not. I removed SIGTERM and am keeping only SIGINT for now; looking at the previous PR that introduced check_signals, we really just need to check for Ctrl+C.

while (!timed_out && signal_received == -1 &&
!(done = get_request->Wait(iteration_timeout))) {
if (remaining_timeout >= 0) {
remaining_timeout -= iteration_timeout;
iteration_timeout = std::min(remaining_timeout, iteration_timeout);
timed_out = remaining_timeout <= 0;
}
}
}

Expand Down Expand Up @@ -424,8 +430,9 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector<ObjectID> &object_ids,
}
}

if (!signal_status.ok()) {
return signal_status;
if (signal_received != -1) {
return Status::Interrupted("Interrupted by signal: " +
std::to_string(signal_received));
} else if (done) {
return Status::OK();
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ class CoreWorkerMemoryStore {
instrumented_io_context &io_context,
ReferenceCounter *counter = nullptr,
std::shared_ptr<raylet::RayletClient> raylet_client = nullptr,
std::function<Status()> check_signals = nullptr,
std::function<void(const RayObject &)> unhandled_exception_handler = nullptr,
std::function<std::shared_ptr<RayObject>(const RayObject &object,
const ObjectID &object_id)>
Expand Down Expand Up @@ -221,9 +220,6 @@ class CoreWorkerMemoryStore {
std::vector<std::function<void(std::shared_ptr<RayObject>)>>>
object_async_get_requests_ ABSL_GUARDED_BY(mu_);

/// Function passed in to be called to check for signals (e.g., Ctrl-C).
std::function<Status()> check_signals_;

/// Function called to report unhandled exceptions.
std::function<void(const RayObject &)> unhandled_exception_handler_;

Expand Down
9 changes: 3 additions & 6 deletions src/ray/core_worker/test/memory_store_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,9 @@ TEST(TestMemoryStore, TestReportUnhandledErrors) {

std::shared_ptr<CoreWorkerMemoryStore> provider =
std::make_shared<CoreWorkerMemoryStore>(
io_context.GetIoService(),
nullptr,
nullptr,
nullptr,
[&](const RayObject &obj) { unhandled_count++; });
io_context.GetIoService(), nullptr, nullptr, [&](const RayObject &obj) {
unhandled_count++;
});
RayObject obj1(rpc::ErrorType::TASK_EXECUTION_EXCEPTION);
RayObject obj2(rpc::ErrorType::TASK_EXECUTION_EXCEPTION);
auto id1 = ObjectID::FromRandom();
Expand Down Expand Up @@ -203,7 +201,6 @@ TEST(TestMemoryStore, TestObjectAllocator) {
nullptr,
nullptr,
nullptr,
nullptr,
std::move(my_object_allocator));
const int32_t max_rounds = 1000;
const std::string hello = "hello";
Expand Down
Loading