diff --git a/cpp2sky/internal/BUILD b/cpp2sky/internal/BUILD index 2769991..b68a5f4 100644 --- a/cpp2sky/internal/BUILD +++ b/cpp2sky/internal/BUILD @@ -10,6 +10,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + "@com_github_grpc_grpc//:grpc++", "@skywalking_data_collect_protocol//language-agent:tracing_protocol_cc_grpc", ], ) diff --git a/cpp2sky/internal/async_client.h b/cpp2sky/internal/async_client.h index a06cfa9..a2f88a7 100644 --- a/cpp2sky/internal/async_client.h +++ b/cpp2sky/internal/async_client.h @@ -14,18 +14,19 @@ #pragma once -#include - #include #include +#include "google/protobuf/message.h" +#include "grpcpp/generic/generic_stub.h" +#include "grpcpp/grpcpp.h" #include "language-agent/Tracing.pb.h" namespace cpp2sky { -using TracerRequestType = skywalking::v3::SegmentObject; -using TracerResponseType = skywalking::v3::Commands; - +/** + * Template base class for gRPC async client. + */ template class AsyncClientBase { public: @@ -36,12 +37,21 @@ class AsyncClientBase { */ virtual void sendMessage(RequestType message) = 0; + /** + * Reset the client. This should be called when the client is no longer + * needed. + */ virtual void resetClient() = 0; }; -using AsyncClient = AsyncClientBase; -using AsyncClientPtr = std::unique_ptr; +template +using AsyncClientBasePtr = + std::unique_ptr>; +/** + * Template base class for gRPC async stream. The stream is used to represent + * a single gRPC stream/request. + */ template class AsyncStreamBase { public: @@ -57,12 +67,52 @@ template using AsyncStreamBasePtr = std::unique_ptr>; -using AsyncStream = AsyncStreamBase; -using AsyncStreamSharedPtr = std::shared_ptr; - +/** + * Tag for async operation. The callback should be called when the operation is + * done. + */ struct AsyncEventTag { std::function callback; }; using AsyncEventTagPtr = std::unique_ptr; +using GrpcClientContextPtr = std::unique_ptr; +using GrpcCompletionQueue = grpc::CompletionQueue; + +/** + * Factory for creating async stream. + */ +template +class AsyncStreamFactoryBase { + public: + virtual ~AsyncStreamFactoryBase() = default; + + using AsyncStreamPtr = AsyncStreamBasePtr; + using GrpcStub = grpc::TemplatedGenericStub; + + virtual AsyncStreamPtr createStream(GrpcClientContextPtr client_ctx, + GrpcStub& stub, GrpcCompletionQueue& cq, + AsyncEventTag& basic_event_tag, + AsyncEventTag& write_event_tag) = 0; +}; + +template +using AsyncStreamFactoryBasePtr = + std::unique_ptr>; + +using TraceRequestType = skywalking::v3::SegmentObject; +using TraceResponseType = skywalking::v3::Commands; + +using TraceAsyncStream = AsyncStreamBase; +using TraceAsyncStreamPtr = + AsyncStreamBasePtr; + +using TraceAsyncStreamFactory = + AsyncStreamFactoryBase; +using TraceAsyncStreamFactoryPtr = + AsyncStreamFactoryBasePtr; + +using TraceAsyncClient = AsyncClientBase; +using TraceAsyncClientPtr = std::unique_ptr; + } // namespace cpp2sky diff --git a/source/grpc_async_client_impl.cc b/source/grpc_async_client_impl.cc index 78faca7..f4f7d6b 100644 --- a/source/grpc_async_client_impl.cc +++ b/source/grpc_async_client_impl.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "grpc_async_client_impl.h" +#include "source/grpc_async_client_impl.h" #include @@ -62,15 +62,50 @@ void EventLoopThread::gogo() { } } -GrpcAsyncSegmentReporterClient::GrpcAsyncSegmentReporterClient( - const std::string& address, const std::string& token, - CredentialsSharedPtr cred) - : stub_(grpc::CreateChannel(address, cred)) { - if (!token.empty()) { - client_ctx_.AddMetadata(AuthenticationKey, token); +TraceAsyncStreamImpl::TraceAsyncStreamImpl(GrpcClientContextPtr client_ctx, + TraceGrpcStub& stub, + GrpcCompletionQueue& cq, + AsyncEventTag& basic_event_tag, + AsyncEventTag& write_event_tag) + : client_ctx_(std::move(client_ctx)), + basic_event_tag_(basic_event_tag), + write_event_tag_(write_event_tag) { + if (client_ctx_ == nullptr) { + client_ctx_.reset(new grpc::ClientContext()); } - basic_event_tag_.reset(new AsyncEventTag{[this](bool ok) { + request_writer_ = + stub.PrepareCall(client_ctx_.get(), TraceCollectMethod, &cq); + request_writer_->StartCall(reinterpret_cast(&basic_event_tag_)); +} + +void TraceAsyncStreamImpl::sendMessage(TraceRequestType message) { + request_writer_->Write(message, reinterpret_cast(&write_event_tag_)); +} + +TraceAsyncStreamPtr TraceAsyncStreamFactoryImpl::createStream( + GrpcClientContextPtr client_ctx, TraceGrpcStub& stub, + GrpcCompletionQueue& cq, AsyncEventTag& basic_event_tag, + AsyncEventTag& write_event_tag) { + return TraceAsyncStreamPtr{new TraceAsyncStreamImpl( + std::move(client_ctx), stub, cq, basic_event_tag, write_event_tag)}; +} + +std::unique_ptr TraceAsyncClientImpl::createClient( + const std::string& address, const std::string& token, + TraceAsyncStreamFactoryPtr factory, CredentialsSharedPtr cred) { + return std::unique_ptr{new TraceAsyncClientImpl( + address, token, std::move(factory), std::move(cred))}; +} + +TraceAsyncClientImpl::TraceAsyncClientImpl(const std::string& address, + const std::string& token, + TraceAsyncStreamFactoryPtr factory, + CredentialsSharedPtr cred) + : token_(token), + stream_factory_(std::move(factory)), + stub_(grpc::CreateChannel(address, cred)) { + basic_event_tag_.callback = [this](bool ok) { if (client_reset_) { return; } @@ -94,9 +129,9 @@ GrpcAsyncSegmentReporterClient::GrpcAsyncSegmentReporterClient( // Reset stream and try to create a new one. startStream(); } - }}); + }; - write_event_tag_.reset(new AsyncEventTag{[this](bool ok) { + write_event_tag_.callback = [this](bool ok) { if (ok) { trace("[Reporter] Stream {} message sending success.", fmt::ptr(this)); messages_sent_++; @@ -106,13 +141,18 @@ GrpcAsyncSegmentReporterClient::GrpcAsyncSegmentReporterClient( } // Delegate the event to basic_event_tag_ to trigger the next task or // reset the stream. - basic_event_tag_->callback(ok); - }}); + basic_event_tag_.callback(ok); + }; + + // If the factory is not provided, use the default one. + if (stream_factory_ == nullptr) { + stream_factory_.reset(new TraceAsyncStreamFactoryImpl()); + } startStream(); } -void GrpcAsyncSegmentReporterClient::sendMessageOnce() { +void TraceAsyncClientImpl::sendMessageOnce() { bool expect_idle = true; if (event_loop_idle_.compare_exchange_strong(expect_idle, false)) { assert(active_stream_ != nullptr); @@ -128,22 +168,31 @@ void GrpcAsyncSegmentReporterClient::sendMessageOnce() { } } -void GrpcAsyncSegmentReporterClient::startStream() { - resetStream(); // Reset stream before creating a new one. +void TraceAsyncClientImpl::startStream() { + if (active_stream_ != nullptr) { + resetStream(); // Reset stream before creating a new one. + } + + // Create the unique client context for the new stream. + // Each stream should have its own context. + auto client_ctx = GrpcClientContextPtr{new grpc::ClientContext()}; + if (!token_.empty()) { + client_ctx->AddMetadata(AuthenticationKey, token_); + } - active_stream_ = std::make_shared( - stub_.PrepareCall(&client_ctx_, TraceCollectMethod, &event_loop_.cq_), - basic_event_tag_.get(), write_event_tag_.get()); + active_stream_ = stream_factory_->createStream( + std::move(client_ctx), stub_, event_loop_.cq_, basic_event_tag_, + write_event_tag_); info("[Reporter] Stream {} has created.", fmt::ptr(active_stream_.get())); } -void GrpcAsyncSegmentReporterClient::resetStream() { +void TraceAsyncClientImpl::resetStream() { info("[Reporter] Stream {} has deleted.", fmt::ptr(active_stream_.get())); active_stream_.reset(); } -void GrpcAsyncSegmentReporterClient::sendMessage(TracerRequestType message) { +void TraceAsyncClientImpl::sendMessage(TraceRequestType message) { messages_total_++; const size_t pending = message_buffer_.size(); @@ -157,17 +206,4 @@ void GrpcAsyncSegmentReporterClient::sendMessage(TracerRequestType message) { sendMessageOnce(); } -SegmentReporterStream::SegmentReporterStream( - TraceReaderWriterPtr request_writer, AsyncEventTag* basic_event_tag, - AsyncEventTag* write_event_tag) - : request_writer_(std::move(request_writer)), - basic_event_tag_(basic_event_tag), - write_event_tag_(write_event_tag) { - request_writer_->StartCall(reinterpret_cast(basic_event_tag_)); -} - -void SegmentReporterStream::sendMessage(TracerRequestType message) { - request_writer_->Write(message, reinterpret_cast(write_event_tag_)); -} - } // namespace cpp2sky diff --git a/source/grpc_async_client_impl.h b/source/grpc_async_client_impl.h index 23537c8..251727b 100644 --- a/source/grpc_async_client_impl.h +++ b/source/grpc_async_client_impl.h @@ -14,14 +14,11 @@ #pragma once -#include -#include - -#include #include #include #include #include +#include #include #include "cpp2sky/config.pb.h" @@ -31,6 +28,14 @@ namespace cpp2sky { +using CredentialsSharedPtr = std::shared_ptr; + +using TraceGrpcStub = + grpc::TemplatedGenericStub; +using TraceReaderWriter = + grpc::ClientAsyncReaderWriter; +using TraceReaderWriterPtr = std::unique_ptr; + class EventLoopThread { public: EventLoopThread() : thread_([this] { this->gogo(); }) {} @@ -53,40 +58,60 @@ class EventLoopThread { void gogo(); }; -using CredentialsSharedPtr = std::shared_ptr; -using TracerReaderWriter = - grpc::ClientAsyncReaderWriter; -using TraceReaderWriterPtr = std::unique_ptr; - -class SegmentReporterStream : public AsyncStream { +class TraceAsyncStreamImpl : public TraceAsyncStream { public: - SegmentReporterStream(TraceReaderWriterPtr request_writer, - AsyncEventTag* basic_event_tag, - AsyncEventTag* write_event_tag); + TraceAsyncStreamImpl(GrpcClientContextPtr client_ctx, TraceGrpcStub& stub, + GrpcCompletionQueue& cq, AsyncEventTag& basic_event_tag, + AsyncEventTag& write_event_tag); // AsyncStream - void sendMessage(TracerRequestType message) override; + void sendMessage(TraceRequestType message) override; private: + GrpcClientContextPtr client_ctx_; TraceReaderWriterPtr request_writer_; - AsyncEventTag* basic_event_tag_; - AsyncEventTag* write_event_tag_; + AsyncEventTag& basic_event_tag_; + AsyncEventTag& write_event_tag_; }; -class GrpcAsyncSegmentReporterClient : public AsyncClient { +class TraceAsyncStreamFactoryImpl : public TraceAsyncStreamFactory { public: - GrpcAsyncSegmentReporterClient(const std::string& address, - const std::string& token, - CredentialsSharedPtr cred); - ~GrpcAsyncSegmentReporterClient() override { + TraceAsyncStreamFactoryImpl() = default; + + TraceAsyncStreamPtr createStream(GrpcClientContextPtr client_ctx, + GrpcStub& stub, GrpcCompletionQueue& cq, + AsyncEventTag& basic_event_tag, + AsyncEventTag& write_event_tag) override; +}; + +class TraceAsyncClientImpl : public TraceAsyncClient { + public: + /** + * Create a new GrpcAsyncSegmentReporterClient. + * + * @param address The address of the server. + * @param token The optional token used to authenticate the client. + * If non-empty token is provided, the client will send the token + * to the server in the metadata. + * @param cred The credentials for creating the channel. + * @param factory The factory function to create the stream from the + * request writer and event tags. In most cases, the default factory + * should be used. + */ + static std::unique_ptr createClient( + const std::string& address, const std::string& token, + TraceAsyncStreamFactoryPtr factory = nullptr, + CredentialsSharedPtr cred = grpc::InsecureChannelCredentials()); + + ~TraceAsyncClientImpl() override { if (!client_reset_) { resetClient(); } } // AsyncClient - void sendMessage(TracerRequestType message) override; + void sendMessage(TraceRequestType message) override; void resetClient() override { // After this is called, no more events will be processed. client_reset_ = true; @@ -96,25 +121,33 @@ class GrpcAsyncSegmentReporterClient : public AsyncClient { } protected: + TraceAsyncClientImpl( + const std::string& address, const std::string& token, + TraceAsyncStreamFactoryPtr factory = nullptr, + CredentialsSharedPtr cred = grpc::InsecureChannelCredentials()); + // Start or re-create the stream that used to send messages. - virtual void startStream(); + void startStream(); void resetStream(); void markEventLoopIdle() { event_loop_idle_.store(true); } void sendMessageOnce(); + const std::string token_; + TraceAsyncStreamFactoryPtr stream_factory_; + TraceGrpcStub stub_; + // This may be operated by multiple threads. std::atomic messages_total_{0}; std::atomic messages_dropped_{0}; std::atomic messages_sent_{0}; EventLoopThread event_loop_; - grpc::ClientContext client_ctx_; std::atomic client_reset_{false}; - ValueBuffer message_buffer_; + ValueBuffer message_buffer_; - AsyncEventTagPtr basic_event_tag_; - AsyncEventTagPtr write_event_tag_; + AsyncEventTag basic_event_tag_; + AsyncEventTag write_event_tag_; // The Write() of the stream could only be called once at a time // until the previous Write() is finished (callback is called). @@ -127,8 +160,7 @@ class GrpcAsyncSegmentReporterClient : public AsyncClient { // occupied by the first operation (startStream). std::atomic event_loop_idle_{false}; - grpc::TemplatedGenericStub stub_; - AsyncStreamSharedPtr active_stream_; + TraceAsyncStreamPtr active_stream_; }; } // namespace cpp2sky diff --git a/source/tracer_impl.cc b/source/tracer_impl.cc index 3a2eec6..e7320c1 100644 --- a/source/tracer_impl.cc +++ b/source/tracer_impl.cc @@ -32,7 +32,8 @@ TracerImpl::TracerImpl(const TracerConfig& config, CredentialsSharedPtr cred) init(config, cred); } -TracerImpl::TracerImpl(const TracerConfig& config, AsyncClientPtr async_client) +TracerImpl::TracerImpl(const TracerConfig& config, + TraceAsyncClientPtr async_client) : async_client_(std::move(async_client)), segment_factory_(config) { init(config, nullptr); } @@ -69,12 +70,11 @@ void TracerImpl::init(const TracerConfig& config, CredentialsSharedPtr cred) { spdlog::set_level(spdlog::level::warn); if (async_client_ == nullptr) { - if (config.protocol() == Protocol::GRPC) { - async_client_.reset(new GrpcAsyncSegmentReporterClient( - config.address(), config.token(), cred)); - } else { - throw TracerException("REST is not supported."); + if (config.protocol() != Protocol::GRPC) { + throw TracerException("Only GRPC is supported."); } + async_client_ = TraceAsyncClientImpl::createClient( + config.address(), config.token(), nullptr, std::move(cred)); } ignore_matcher_.reset(new SuffixMatcher( diff --git a/source/tracer_impl.h b/source/tracer_impl.h index 2be2c7d..febd577 100644 --- a/source/tracer_impl.h +++ b/source/tracer_impl.h @@ -34,8 +34,8 @@ using CdsResponse = skywalking::v3::Commands; class TracerImpl : public Tracer { public: - TracerImpl(const TracerConfig& config, CredentialsSharedPtr cred); - TracerImpl(const TracerConfig& config, AsyncClientPtr async_client); + TracerImpl(const TracerConfig& config, CredentialsSharedPtr credentials); + TracerImpl(const TracerConfig& config, TraceAsyncClientPtr async_client); ~TracerImpl(); TracingContextSharedPtr newContext() override; @@ -46,7 +46,7 @@ class TracerImpl : public Tracer { private: void init(const TracerConfig& config, CredentialsSharedPtr cred); - AsyncClientPtr async_client_; + TraceAsyncClientPtr async_client_; TracingContextFactory segment_factory_; MatcherPtr ignore_matcher_; }; diff --git a/test/grpc_async_client_test.cc b/test/grpc_async_client_test.cc index c3bae59..4016d6d 100644 --- a/test/grpc_async_client_test.cc +++ b/test/grpc_async_client_test.cc @@ -41,10 +41,13 @@ struct TestStats { uint64_t pending_{}; }; -class TestGrpcAsyncSegmentReporterClient - : public GrpcAsyncSegmentReporterClient { +class TestTraceAsyncClient : public TraceAsyncClientImpl { public: - using GrpcAsyncSegmentReporterClient::GrpcAsyncSegmentReporterClient; + TestTraceAsyncClient(const std::string& address, const std::string& token, + TraceAsyncStreamFactoryPtr stream_factory, + CredentialsSharedPtr credentials) + : TraceAsyncClientImpl(address, token, std::move(stream_factory), + std::move(credentials)) {} TestStats getTestStats() const { TestStats stats(messages_total_.load(), messages_dropped_.load(), @@ -52,28 +55,47 @@ class TestGrpcAsyncSegmentReporterClient return stats; } - void notifyWriteEvent(bool success) { write_event_tag_->callback(success); } - void notifyStartEvent(bool success) { basic_event_tag_->callback(success); } + void notifyWriteEvent(bool success) { write_event_tag_.callback(success); } + void notifyStartEvent(bool success) { basic_event_tag_.callback(success); } uint64_t bufferSize() const { return message_buffer_.size(); } +}; - void startStream() override { - resetStream(); - active_stream_ = mock_stream_; +class TestTraceAsyncStreamFactory : public TraceAsyncStreamFactory { + public: + TestTraceAsyncStreamFactory(std::shared_ptr mock_stream) + : mock_stream_(mock_stream) {} + + class TestTraceAsyncStream : public TraceAsyncStream { + public: + TestTraceAsyncStream(std::shared_ptr mock_stream) + : mock_stream_(mock_stream) {} + void sendMessage(TraceRequestType message) override { + mock_stream_->sendMessage(std::move(message)); + } + std::shared_ptr mock_stream_; + }; + + TraceAsyncStreamPtr createStream(GrpcClientContextPtr, GrpcStub&, + GrpcCompletionQueue&, AsyncEventTag&, + AsyncEventTag&) override { + return TraceAsyncStreamPtr{new TestTraceAsyncStream(mock_stream_)}; } - std::shared_ptr mock_stream_ = - std::make_shared(); + std::shared_ptr mock_stream_; }; -class GrpcAsyncSegmentReporterClientTest : public testing::Test { +class TraceAsyncClientImplTest : public testing::Test { public: - GrpcAsyncSegmentReporterClientTest() { - client_.reset(new TestGrpcAsyncSegmentReporterClient( - address_, token_, grpc::InsecureChannelCredentials())); + TraceAsyncClientImplTest() { + client_.reset(new TestTraceAsyncClient( + address_, token_, + TraceAsyncStreamFactoryPtr{ + new TestTraceAsyncStreamFactory(mock_stream_)}, + grpc::InsecureChannelCredentials())); } - ~GrpcAsyncSegmentReporterClientTest() { + ~TraceAsyncClientImplTest() { client_->resetClient(); client_.reset(); } @@ -82,12 +104,15 @@ class GrpcAsyncSegmentReporterClientTest : public testing::Test { std::string address_{"localhost:50051"}; std::string token_{"token"}; - std::unique_ptr client_; + std::shared_ptr mock_stream_ = + std::make_shared(); + + std::unique_ptr client_; }; -TEST_F(GrpcAsyncSegmentReporterClientTest, SendMessageTest) { +TEST_F(TraceAsyncClientImplTest, SendMessageTest) { skywalking::v3::SegmentObject fake_message; - EXPECT_CALL(*client_->mock_stream_, sendMessage(_)).Times(0); + EXPECT_CALL(*mock_stream_, sendMessage(_)).Times(0); client_->sendMessage(fake_message); auto stats = client_->getTestStats(); @@ -109,7 +134,7 @@ TEST_F(GrpcAsyncSegmentReporterClientTest, SendMessageTest) { EXPECT_EQ(stats.pending_, 1); EXPECT_EQ(client_->bufferSize(), 1); - EXPECT_CALL(*client_->mock_stream_, sendMessage(_)); + EXPECT_CALL(*mock_stream_, sendMessage(_)); client_->notifyStartEvent(true); sleep(1); // wait for the event loop to process the event. @@ -137,7 +162,7 @@ TEST_F(GrpcAsyncSegmentReporterClientTest, SendMessageTest) { // Send another message. This time the stream is ready and // previous message is sent successfully. So the new message // should be sent immediately. - EXPECT_CALL(*client_->mock_stream_, sendMessage(_)); + EXPECT_CALL(*mock_stream_, sendMessage(_)); client_->sendMessage(fake_message); sleep(1); // wait for the event loop to process the event. diff --git a/test/mocks.h b/test/mocks.h index bcaaf07..d9f81c1 100644 --- a/test/mocks.h +++ b/test/mocks.h @@ -17,8 +17,6 @@ #include #include -#include - #include "cpp2sky/internal/async_client.h" #include "cpp2sky/internal/random_generator.h" @@ -33,14 +31,14 @@ class MockRandomGenerator : public RandomGenerator { MOCK_METHOD(std::string, uuid, ()); }; -class MockAsyncStream : public AsyncStream { +class MockTraceAsyncStream : public TraceAsyncStream { public: - MOCK_METHOD(void, sendMessage, (TracerRequestType)); + MOCK_METHOD(void, sendMessage, (TraceRequestType)); }; -class MockAsyncClient : public AsyncClient { +class MockTraceAsyncClient : public TraceAsyncClient { public: - MOCK_METHOD(void, sendMessage, (TracerRequestType)); + MOCK_METHOD(void, sendMessage, (TraceRequestType)); MOCK_METHOD(void, resetClient, ()); }; diff --git a/test/tracer_test.cc b/test/tracer_test.cc index 9ba8ed7..462eb07 100644 --- a/test/tracer_test.cc +++ b/test/tracer_test.cc @@ -27,8 +27,8 @@ TEST(TracerTest, MatchedOpShouldIgnored) { TracerConfig config; *config.add_ignore_operation_name_suffix() = "/ignored"; - TracerImpl tracer(config, - AsyncClientPtr{new testing::NiceMock()}); + TracerImpl tracer(config, TraceAsyncClientPtr{ + new testing::NiceMock()}); auto context = tracer.newContext(); auto span = context->createEntrySpan(); @@ -41,8 +41,8 @@ TEST(TracerTest, MatchedOpShouldIgnored) { TEST(TracerTest, NotClosedSpanExists) { TracerConfig config; - TracerImpl tracer(config, - AsyncClientPtr{new testing::NiceMock()}); + TracerImpl tracer(config, TraceAsyncClientPtr{ + new testing::NiceMock()}); auto context = tracer.newContext(); auto span = context->createEntrySpan(); @@ -54,8 +54,8 @@ TEST(TracerTest, NotClosedSpanExists) { TEST(TracerTest, Success) { TracerConfig config; - auto mock_reporter = std::unique_ptr{ - new testing::NiceMock()}; + auto mock_reporter = std::unique_ptr{ + new testing::NiceMock()}; EXPECT_CALL(*mock_reporter, sendMessage(_)); TracerImpl tracer(config, std::move(mock_reporter));