Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Mutex in two places #50

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions async_grpc/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@
#include "async_grpc/retry.h"
#include "async_grpc/rpc_handler_interface.h"
#include "async_grpc/rpc_service_method_traits.h"

#include "glog/logging.h"
#include "grpc++/grpc++.h"
#include "grpc++/impl/codegen/client_unary_call.h"
#include "grpc++/impl/codegen/proto_utils.h"
#include "grpc++/impl/codegen/sync_stream.h"

#include "glog/logging.h"

namespace async_grpc {

// Wraps a method invocation for all rpc types, unary, client streaming,
Expand Down
6 changes: 3 additions & 3 deletions async_grpc/completion_queue_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
* limitations under the License.
*/

#include "async_grpc/completion_queue_pool.h"

#include <cstdlib>

#include "async_grpc/async_client.h"
#include "async_grpc/completion_queue_pool.h"
#include "common/make_unique.h"
#include "glog/logging.h"

Expand Down Expand Up @@ -89,8 +90,7 @@ void CompletionQueuePool::Shutdown() {
}

CompletionQueuePool::CompletionQueuePool()
: number_completion_queues_(kDefaultNumberCompletionQueues) {
}
: number_completion_queues_(kDefaultNumberCompletionQueues) {}

CompletionQueuePool::~CompletionQueuePool() {
LOG(INFO) << "~CompletionQueuePool";
Expand Down
1 change: 1 addition & 0 deletions async_grpc/completion_queue_thread.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#define CPP_GRPC_COMMON_COMPLETION_QUEUE_THREAD_H_

#include <grpc++/grpc++.h>

#include <memory>
#include <thread>

Expand Down
4 changes: 4 additions & 0 deletions async_grpc/event_queue_thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ EventQueueThread::EventQueueThread() {

EventQueue* EventQueueThread::event_queue() { return event_queue_.get(); }

const EventQueue* EventQueueThread::event_queue() const {
return event_queue_.get();
}

void EventQueueThread::Start(EventQueueRunner runner) {
CHECK(!thread_);
EventQueue* event_queue = event_queue_.get();
Expand Down
1 change: 1 addition & 0 deletions async_grpc/event_queue_thread.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class EventQueueThread {
EventQueueThread();

EventQueue* event_queue();
const EventQueue* event_queue() const;

void Start(EventQueueRunner runner);
void Shutdown();
Expand Down
3 changes: 2 additions & 1 deletion async_grpc/opencensus_span.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ void OpencensusSpan::End() { span_.End(); }

OpencensusSpan::OpencensusSpan(const std::string& name,
const OpencensusSpan* parent)
: span_(opencensus::trace::Span::StartSpan(name, parent ? &parent->span_: nullptr)) {}
: span_(opencensus::trace::Span::StartSpan(
name, parent ? &parent->span_ : nullptr)) {}

} // namespace async_grpc

Expand Down
5 changes: 3 additions & 2 deletions async_grpc/retry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,20 @@
* limitations under the License.
*/

#include "async_grpc/retry.h"

#include <chrono>
#include <cmath>
#include <thread>

#include "async_grpc/retry.h"
#include "glog/logging.h"

namespace async_grpc {

RetryStrategy CreateRetryStrategy(RetryIndicator retry_indicator,
RetryDelayCalculator retry_delay_calculator) {
return [retry_indicator, retry_delay_calculator](
int failed_attempts, const ::grpc::Status &status) {
int failed_attempts, const ::grpc::Status &status) {
if (!retry_indicator(failed_attempts, status)) {
return optional<Duration>();
}
Expand Down
10 changes: 3 additions & 7 deletions async_grpc/rpc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
*/

#include "async_grpc/rpc.h"
#include "async_grpc/service.h"

#include "async_grpc/common/make_unique.h"
#include "async_grpc/service.h"
#include "glog/logging.h"

namespace async_grpc {
Expand Down Expand Up @@ -320,8 +320,6 @@ bool Rpc::IsAnyEventPending() {

std::weak_ptr<Rpc> Rpc::GetWeakPtr() { return weak_ptr_factory_(this); }

ActiveRpcs::ActiveRpcs() : lock_() {}

void Rpc::InitializeReadersAndWriters(
::grpc::internal::RpcMethod::RpcType rpc_type) {
switch (rpc_type) {
Expand Down Expand Up @@ -349,23 +347,22 @@ void Rpc::InitializeReadersAndWriters(
}
}

ActiveRpcs::ActiveRpcs() {}

ActiveRpcs::~ActiveRpcs() {
common::MutexLocker locker(&lock_);
if (!rpcs_.empty()) {
LOG(FATAL) << "RPCs still in flight!";
}
}

std::shared_ptr<Rpc> ActiveRpcs::Add(std::unique_ptr<Rpc> rpc) {
common::MutexLocker locker(&lock_);
std::shared_ptr<Rpc> shared_ptr_rpc = std::move(rpc);
const auto result = rpcs_.emplace(shared_ptr_rpc.get(), shared_ptr_rpc);
CHECK(result.second) << "RPC already active.";
return shared_ptr_rpc;
}

bool ActiveRpcs::Remove(Rpc* rpc) {
common::MutexLocker locker(&lock_);
auto it = rpcs_.find(rpc);
if (it != rpcs_.end()) {
rpcs_.erase(it);
Expand All @@ -379,7 +376,6 @@ Rpc::WeakPtrFactory ActiveRpcs::GetWeakPtrFactory() {
}

std::weak_ptr<Rpc> ActiveRpcs::GetWeakPtr(Rpc* rpc) {
common::MutexLocker locker(&lock_);
auto it = rpcs_.find(rpc);
CHECK(it != rpcs_.end());
return it->second;
Expand Down
2 changes: 1 addition & 1 deletion async_grpc/rpc.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class Rpc {
bool IsRpcEventPending(Event event);
bool IsAnyEventPending();
void SetEventQueue(EventQueue* event_queue) { event_queue_ = event_queue; }
const EventQueue* event_queue() const { return event_queue_; }
EventQueue* event_queue() { return event_queue_; }
std::weak_ptr<Rpc> GetWeakPtr();
RpcHandlerInterface* handler() { return handler_.get(); }
Expand Down Expand Up @@ -202,7 +203,6 @@ class ActiveRpcs {
private:
std::weak_ptr<Rpc> GetWeakPtr(Rpc* rpc);

common::Mutex lock_;
std::map<Rpc*, std::shared_ptr<Rpc>> rpcs_;
};

Expand Down
29 changes: 14 additions & 15 deletions async_grpc/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ void Server::Builder::SetServerAddress(const std::string& server_address) {
}

void Server::Builder::SetMaxReceiveMessageSize(int max_receive_message_size) {
CHECK_GT(max_receive_message_size, 0) << "max_receive_message_size must be larger than 0.";
CHECK_GT(max_receive_message_size, 0)
<< "max_receive_message_size must be larger than 0.";
options_.max_receive_message_size = max_receive_message_size;
}

void Server::Builder::SetMaxSendMessageSize(int max_send_message_size) {
CHECK_GT(max_send_message_size, 0) << "max_send_message_size must be larger than 0.";
CHECK_GT(max_send_message_size, 0)
<< "max_send_message_size must be larger than 0.";
options_.max_send_message_size = max_send_message_size;
}

Expand All @@ -63,19 +65,19 @@ void Server::Builder::EnableTracing() {
#endif
}

void Server::Builder::DisableTracing() {
options_.enable_tracing = false;
}
void Server::Builder::DisableTracing() { options_.enable_tracing = false; }

void Server::Builder::SetTracingSamplerProbability(double tracing_sampler_probability) {
void Server::Builder::SetTracingSamplerProbability(
double tracing_sampler_probability) {
options_.tracing_sampler_probability = tracing_sampler_probability;
}

void Server::Builder::SetTracingTaskName(const std::string& tracing_task_name) {
options_.tracing_task_name = tracing_task_name;
}

void Server::Builder::SetTracingGcpProjectId(const std::string& tracing_gcp_project_id) {
void Server::Builder::SetTracingGcpProjectId(
const std::string& tracing_gcp_project_id) {
options_.tracing_gcp_project_id = tracing_gcp_project_id;
}

Expand Down Expand Up @@ -125,7 +127,7 @@ void Server::AddService(
const auto result = services_.emplace(
std::piecewise_construct, std::make_tuple(service_name),
std::make_tuple(service_name, rpc_handler_infos,
[this]() { return SelectNextEventQueueRoundRobin(); }));
[this]() { return SelectNextEventQueue(); }));
CHECK(result.second) << "A service named " << service_name
<< " already exists.";
server_builder_.RegisterService(&result.first->second);
Expand All @@ -142,11 +144,9 @@ void Server::RunCompletionQueue(
}
}

EventQueue* Server::SelectNextEventQueueRoundRobin() {
common::MutexLocker locker(&current_event_queue_id_lock_);
current_event_queue_id_ =
(current_event_queue_id_ + 1) % options_.num_event_threads;
return event_queue_threads_.at(current_event_queue_id_).event_queue();
EventQueue* Server::SelectNextEventQueue() {
return event_queue_threads_.at(rand() % event_queue_threads_.size())
.event_queue();
}

void Server::RunEventQueue(EventQueue* event_queue) {
Expand Down Expand Up @@ -178,13 +178,12 @@ void Server::Start() {
}
#endif


// Start the gRPC server process.
server_ = server_builder_.BuildAndStart();

// Start serving all services on all completion queues.
for (auto& service : services_) {
service.second.StartServing(completion_queue_threads_,
service.second.StartServing(event_queue_threads_, completion_queue_threads_,
execution_context_.get());
}

Expand Down
7 changes: 2 additions & 5 deletions async_grpc/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,12 @@
#include "async_grpc/rpc_handler.h"
#include "async_grpc/rpc_service_method_traits.h"
#include "async_grpc/service.h"

#include "grpc++/grpc++.h"

namespace async_grpc {
namespace {

constexpr int kDefaultMaxMessageSize = 10 * 1024 * 1024; // 10 MB
constexpr int kDefaultMaxMessageSize = 10 * 1024 * 1024; // 10 MB
constexpr double kDefaultTracingSamplerProbability = 0.01; // 1 Percent

} // namespace
Expand Down Expand Up @@ -194,7 +193,7 @@ class Server {
Server& operator=(const Server&) = delete;
void RunCompletionQueue(::grpc::ServerCompletionQueue* completion_queue);
void RunEventQueue(Rpc::EventQueue* event_queue);
Rpc::EventQueue* SelectNextEventQueueRoundRobin();
Rpc::EventQueue* SelectNextEventQueue();

Options options_;

Expand All @@ -209,8 +208,6 @@ class Server {

// Threads processing RPC events.
std::vector<EventQueueThread> event_queue_threads_;
common::Mutex current_event_queue_id_lock_;
int current_event_queue_id_ = 0;

// Map of service names to services.
std::map<std::string, Service> services_;
Expand Down
32 changes: 23 additions & 9 deletions async_grpc/service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
* limitations under the License.
*/

#include "async_grpc/server.h"

#include <cstdlib>

#include "async_grpc/server.h"
#include "glog/logging.h"
#include "grpc++/impl/codegen/proto_utils.h"

Expand All @@ -38,15 +37,27 @@ Service::Service(const std::string& service_name,
}

void Service::StartServing(
const std::vector<EventQueueThread>& event_queue_threads,
std::vector<CompletionQueueThread>& completion_queue_threads,
ExecutionContext* execution_context) {
CHECK(active_rpcs_.empty());
int i = 0;

for (const auto& event_queue_thread : event_queue_threads) {
const auto* event_queue = event_queue_thread.event_queue();
// TODO(cschuet): Prettify.
active_rpcs_[event_queue];
LOG(INFO) << "Creating ActiveRpcs";
}

for (const auto& rpc_handler_info : rpc_handler_infos_) {
for (auto& completion_queue_thread : completion_queue_threads) {
std::shared_ptr<Rpc> rpc = active_rpcs_.Add(common::make_unique<Rpc>(
EventQueue* event_queue = event_queue_selector_();
auto& active_rpcs = active_rpcs_.at(event_queue);
std::shared_ptr<Rpc> rpc = active_rpcs.Add(common::make_unique<Rpc>(
i, completion_queue_thread.completion_queue(),
event_queue_selector_(), execution_context, rpc_handler_info.second,
this, active_rpcs_.GetWeakPtrFactory()));
this, active_rpcs.GetWeakPtrFactory()));
rpc->RequestNextMethodInvocation();
}
++i;
Expand Down Expand Up @@ -81,13 +92,13 @@ void Service::HandleNewConnection(Rpc* rpc, bool ok) {
if (ok) {
LOG(WARNING) << "Server shutting down. Refusing to handle new RPCs.";
}
active_rpcs_.Remove(rpc);
active_rpcs_.at(rpc->event_queue()).Remove(rpc);
return;
}

if (!ok) {
LOG(ERROR) << "Failed to establish connection for unknown reason.";
active_rpcs_.Remove(rpc);
active_rpcs_.at(rpc->event_queue()).Remove(rpc);
}

if (ok) {
Expand All @@ -97,8 +108,11 @@ void Service::HandleNewConnection(Rpc* rpc, bool ok) {
// Create new active rpc to handle next connection and register it for the
// incoming connection. Assign event queue in a round-robin fashion.
std::unique_ptr<Rpc> new_rpc = rpc->Clone();
new_rpc->SetEventQueue(event_queue_selector_());
active_rpcs_.Add(std::move(new_rpc))->RequestNextMethodInvocation();
auto* next_event_queue = event_queue_selector_();
new_rpc->SetEventQueue(next_event_queue);
active_rpcs_.at(next_event_queue)
.Add(std::move(new_rpc))
->RequestNextMethodInvocation();
}

void Service::HandleRead(Rpc* rpc, bool ok) {
Expand Down Expand Up @@ -139,7 +153,7 @@ void Service::HandleDone(Rpc* rpc, bool ok) { RemoveIfNotPending(rpc); }

void Service::RemoveIfNotPending(Rpc* rpc) {
if (!rpc->IsAnyEventPending()) {
active_rpcs_.Remove(rpc);
active_rpcs_.at(rpc->event_queue()).Remove(rpc);
}
}

Expand Down
7 changes: 5 additions & 2 deletions async_grpc/service.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#ifndef CPP_GRPC_SERVICE_H
#define CPP_GRPC_SERVICE_H

#include <unordered_map>

#include "async_grpc/completion_queue_thread.h"
#include "async_grpc/event_queue_thread.h"
#include "async_grpc/execution_context.h"
Expand All @@ -38,7 +40,8 @@ class Service : public ::grpc::Service {
Service(const std::string& service_name,
const std::map<std::string, RpcHandlerInfo>& rpc_handlers,
EventQueueSelector event_queue_selector);
void StartServing(std::vector<CompletionQueueThread>& completion_queues,
void StartServing(const std::vector<EventQueueThread>& event_queue_threads,
std::vector<CompletionQueueThread>& completion_queues,
ExecutionContext* execution_context);
void HandleEvent(Rpc::Event event, Rpc* rpc, bool ok);
void StopServing();
Expand All @@ -54,7 +57,7 @@ class Service : public ::grpc::Service {

std::map<std::string, RpcHandlerInfo> rpc_handler_infos_;
EventQueueSelector event_queue_selector_;
ActiveRpcs active_rpcs_;
std::unordered_map<const EventQueue*, ActiveRpcs> active_rpcs_;
bool shutting_down_ = false;
};

Expand Down