Skip to content

Commit

Permalink
fix unit test
Browse files Browse the repository at this point in the history
Signed-off-by: wbpcode <[email protected]>
  • Loading branch information
wbpcode committed Aug 15, 2024
1 parent 7b73036 commit c94a4d9
Show file tree
Hide file tree
Showing 9 changed files with 255 additions and 113 deletions.
1 change: 1 addition & 0 deletions cpp2sky/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
"@com_github_grpc_grpc//:grpc++",
"@skywalking_data_collect_protocol//language-agent:tracing_protocol_cc_grpc",
],
)
Expand Down
70 changes: 60 additions & 10 deletions cpp2sky/internal/async_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@

#pragma once

#include <google/protobuf/message.h>

#include <functional>
#include <memory>

#include "google/protobuf/message.h"
#include "grpcpp/generic/generic_stub.h"
#include "grpcpp/grpcpp.h"
#include "language-agent/Tracing.pb.h"

namespace cpp2sky {

using TracerRequestType = skywalking::v3::SegmentObject;
using TracerResponseType = skywalking::v3::Commands;

/**
* Template base class for gRPC async client.
*/
template <class RequestType, class ResponseType>
class AsyncClientBase {
public:
Expand All @@ -36,12 +37,21 @@ class AsyncClientBase {
*/
virtual void sendMessage(RequestType message) = 0;

/**
* Reset the client. This should be called when the client is no longer
* needed.
*/
virtual void resetClient() = 0;
};

using AsyncClient = AsyncClientBase<TracerRequestType, TracerResponseType>;
using AsyncClientPtr = std::unique_ptr<AsyncClient>;
template <class RequestType, class ResponseType>
using AsyncClientBasePtr =
std::unique_ptr<AsyncClientBase<RequestType, ResponseType>>;

/**
* Template base class for gRPC async stream. The stream is used to represent
* a single gRPC stream/request.
*/
template <class RequestType, class ResponseType>
class AsyncStreamBase {
public:
Expand All @@ -57,12 +67,52 @@ template <class RequestType, class ResponseType>
using AsyncStreamBasePtr =
std::unique_ptr<AsyncStreamBase<RequestType, ResponseType>>;

using AsyncStream = AsyncStreamBase<TracerRequestType, TracerResponseType>;
using AsyncStreamSharedPtr = std::shared_ptr<AsyncStream>;

/**
* Tag for async operation. The callback should be called when the operation is
* done.
*/
struct AsyncEventTag {
std::function<void(bool)> callback;
};
using AsyncEventTagPtr = std::unique_ptr<AsyncEventTag>;

using GrpcClientContextPtr = std::unique_ptr<grpc::ClientContext>;
using GrpcCompletionQueue = grpc::CompletionQueue;

/**
* Factory for creating async stream.
*/
template <class RequestType, class ResponseType>
class AsyncStreamFactoryBase {
public:
virtual ~AsyncStreamFactoryBase() = default;

using AsyncStreamPtr = AsyncStreamBasePtr<RequestType, ResponseType>;
using GrpcStub = grpc::TemplatedGenericStub<RequestType, ResponseType>;

virtual AsyncStreamPtr createStream(GrpcClientContextPtr client_ctx,
GrpcStub& stub, GrpcCompletionQueue& cq,
AsyncEventTag& basic_event_tag,
AsyncEventTag& write_event_tag) = 0;
};

template <class RequestType, class ResponseType>
using AsyncStreamFactoryBasePtr =
std::unique_ptr<AsyncStreamFactoryBase<RequestType, ResponseType>>;

using TraceRequestType = skywalking::v3::SegmentObject;
using TraceResponseType = skywalking::v3::Commands;

using TraceAsyncStream = AsyncStreamBase<TraceRequestType, TraceResponseType>;
using TraceAsyncStreamPtr =
AsyncStreamBasePtr<TraceRequestType, TraceResponseType>;

using TraceAsyncStreamFactory =
AsyncStreamFactoryBase<TraceRequestType, TraceResponseType>;
using TraceAsyncStreamFactoryPtr =
AsyncStreamFactoryBasePtr<TraceRequestType, TraceResponseType>;

using TraceAsyncClient = AsyncClientBase<TraceRequestType, TraceResponseType>;
using TraceAsyncClientPtr = std::unique_ptr<TraceAsyncClient>;

} // namespace cpp2sky
102 changes: 69 additions & 33 deletions source/grpc_async_client_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "grpc_async_client_impl.h"
#include "source/grpc_async_client_impl.h"

#include <sys/types.h>

Expand Down Expand Up @@ -62,15 +62,50 @@ void EventLoopThread::gogo() {
}
}

GrpcAsyncSegmentReporterClient::GrpcAsyncSegmentReporterClient(
const std::string& address, const std::string& token,
CredentialsSharedPtr cred)
: stub_(grpc::CreateChannel(address, cred)) {
if (!token.empty()) {
client_ctx_.AddMetadata(AuthenticationKey, token);
TraceAsyncStreamImpl::TraceAsyncStreamImpl(GrpcClientContextPtr client_ctx,
TraceGrpcStub& stub,
GrpcCompletionQueue& cq,
AsyncEventTag& basic_event_tag,
AsyncEventTag& write_event_tag)
: client_ctx_(std::move(client_ctx)),
basic_event_tag_(basic_event_tag),
write_event_tag_(write_event_tag) {
if (client_ctx_ == nullptr) {
client_ctx_.reset(new grpc::ClientContext());
}

basic_event_tag_.reset(new AsyncEventTag{[this](bool ok) {
request_writer_ =
stub.PrepareCall(client_ctx_.get(), TraceCollectMethod, &cq);
request_writer_->StartCall(reinterpret_cast<void*>(&basic_event_tag_));
}

void TraceAsyncStreamImpl::sendMessage(TraceRequestType message) {
request_writer_->Write(message, reinterpret_cast<void*>(&write_event_tag_));
}

TraceAsyncStreamPtr TraceAsyncStreamFactoryImpl::createStream(
GrpcClientContextPtr client_ctx, TraceGrpcStub& stub,
GrpcCompletionQueue& cq, AsyncEventTag& basic_event_tag,
AsyncEventTag& write_event_tag) {
return TraceAsyncStreamPtr{new TraceAsyncStreamImpl(
std::move(client_ctx), stub, cq, basic_event_tag, write_event_tag)};
}

std::unique_ptr<TraceAsyncClientImpl> TraceAsyncClientImpl::createClient(
const std::string& address, const std::string& token,
TraceAsyncStreamFactoryPtr factory, CredentialsSharedPtr cred) {
return std::unique_ptr<TraceAsyncClientImpl>{new TraceAsyncClientImpl(
address, token, std::move(factory), std::move(cred))};
}

TraceAsyncClientImpl::TraceAsyncClientImpl(const std::string& address,
const std::string& token,
TraceAsyncStreamFactoryPtr factory,
CredentialsSharedPtr cred)
: token_(token),
stream_factory_(std::move(factory)),
stub_(grpc::CreateChannel(address, cred)) {
basic_event_tag_.callback = [this](bool ok) {
if (client_reset_) {
return;
}
Expand All @@ -94,9 +129,9 @@ GrpcAsyncSegmentReporterClient::GrpcAsyncSegmentReporterClient(
// Reset stream and try to create a new one.
startStream();
}
}});
};

write_event_tag_.reset(new AsyncEventTag{[this](bool ok) {
write_event_tag_.callback = [this](bool ok) {
if (ok) {
trace("[Reporter] Stream {} message sending success.", fmt::ptr(this));
messages_sent_++;
Expand All @@ -106,13 +141,18 @@ GrpcAsyncSegmentReporterClient::GrpcAsyncSegmentReporterClient(
}
// Delegate the event to basic_event_tag_ to trigger the next task or
// reset the stream.
basic_event_tag_->callback(ok);
}});
basic_event_tag_.callback(ok);
};

// If the factory is not provided, use the default one.
if (stream_factory_ == nullptr) {
stream_factory_.reset(new TraceAsyncStreamFactoryImpl());
}

startStream();
}

void GrpcAsyncSegmentReporterClient::sendMessageOnce() {
void TraceAsyncClientImpl::sendMessageOnce() {
bool expect_idle = true;
if (event_loop_idle_.compare_exchange_strong(expect_idle, false)) {
assert(active_stream_ != nullptr);
Expand All @@ -128,22 +168,31 @@ void GrpcAsyncSegmentReporterClient::sendMessageOnce() {
}
}

void GrpcAsyncSegmentReporterClient::startStream() {
resetStream(); // Reset stream before creating a new one.
void TraceAsyncClientImpl::startStream() {
if (active_stream_ != nullptr) {
resetStream(); // Reset stream before creating a new one.
}

// Create the unique client context for the new stream.
// Each stream should have its own context.
auto client_ctx = GrpcClientContextPtr{new grpc::ClientContext()};
if (!token_.empty()) {
client_ctx->AddMetadata(AuthenticationKey, token_);
}

active_stream_ = std::make_shared<SegmentReporterStream>(
stub_.PrepareCall(&client_ctx_, TraceCollectMethod, &event_loop_.cq_),
basic_event_tag_.get(), write_event_tag_.get());
active_stream_ = stream_factory_->createStream(
std::move(client_ctx), stub_, event_loop_.cq_, basic_event_tag_,
write_event_tag_);

info("[Reporter] Stream {} has created.", fmt::ptr(active_stream_.get()));
}

void GrpcAsyncSegmentReporterClient::resetStream() {
void TraceAsyncClientImpl::resetStream() {
info("[Reporter] Stream {} has deleted.", fmt::ptr(active_stream_.get()));
active_stream_.reset();
}

void GrpcAsyncSegmentReporterClient::sendMessage(TracerRequestType message) {
void TraceAsyncClientImpl::sendMessage(TraceRequestType message) {
messages_total_++;

const size_t pending = message_buffer_.size();
Expand All @@ -157,17 +206,4 @@ void GrpcAsyncSegmentReporterClient::sendMessage(TracerRequestType message) {
sendMessageOnce();
}

SegmentReporterStream::SegmentReporterStream(
TraceReaderWriterPtr request_writer, AsyncEventTag* basic_event_tag,
AsyncEventTag* write_event_tag)
: request_writer_(std::move(request_writer)),
basic_event_tag_(basic_event_tag),
write_event_tag_(write_event_tag) {
request_writer_->StartCall(reinterpret_cast<void*>(basic_event_tag_));
}

void SegmentReporterStream::sendMessage(TracerRequestType message) {
request_writer_->Write(message, reinterpret_cast<void*>(write_event_tag_));
}

} // namespace cpp2sky
Loading

0 comments on commit c94a4d9

Please sign in to comment.