diff --git a/common/zmqclient.cpp b/common/zmqclient.cpp index 5a84160e9..45f118990 100644 --- a/common/zmqclient.cpp +++ b/common/zmqclient.cpp @@ -25,6 +25,12 @@ ZmqClient::ZmqClient(const std::string& endpoint, const std::string& vrf) initialize(endpoint, vrf); } +ZmqClient::ZmqClient(const std::string& endpoint, uint32_t waitTimeMs) +: m_waitTimeMs(waitTimeMs) +{ + initialize(endpoint); +} + ZmqClient::~ZmqClient() { std::lock_guard lock(m_socketMutex); @@ -55,7 +61,7 @@ void ZmqClient::initialize(const std::string& endpoint, const std::string& vrf) connect(); } - + bool ZmqClient::isConnected() { return m_connected; @@ -137,7 +143,7 @@ void ZmqClient::sendMsg( int zmq_err = 0; int retry_delay = 10; int rc = 0; - for (int i = 0; i <= MQ_MAX_RETRY; ++i) + for (int i = 0; i <= MQ_MAX_RETRY; ++i) { { // ZMQ socket is not thread safe: http://api.zeromq.org/2-1:zmq @@ -146,7 +152,6 @@ void ZmqClient::sendMsg( // Use none block mode to use all bandwidth: http://api.zeromq.org/2-1%3Azmq-send rc = zmq_send(m_socket, m_sendbuffer.data(), serializedlen, ZMQ_NOBLOCK); } - if (rc >= 0) { SWSS_LOG_DEBUG("zmq sended %d bytes", serializedlen); @@ -197,4 +202,11 @@ void ZmqClient::sendMsg( throw system_error(make_error_code(errc::io_error), message); } +// TODO: To be implemented later, required for ZMQ_CLIENT & ZMQ_SERVER +// socket types in response path. +bool ZmqClient::wait( + const std::string &dbName, const std::string &tableName, + const std::vector> &kcos) { + return false; +} } diff --git a/common/zmqclient.h b/common/zmqclient.h index adc36b053..fdfe9e343 100644 --- a/common/zmqclient.h +++ b/common/zmqclient.h @@ -12,8 +12,10 @@ namespace swss { class ZmqClient { public: + ZmqClient(const std::string& endpoint); ZmqClient(const std::string& endpoint, const std::string& vrf); + ZmqClient(const std::string& endpoint, uint32_t waitTimeMs); ~ZmqClient(); bool isConnected(); @@ -23,8 +25,13 @@ class ZmqClient void sendMsg(const std::string& dbName, const std::string& tableName, const std::vector& kcos); + + bool wait(const std::string& dbName, + const std::string& tableName, + const std::vector>& kcos); + private: - void initialize(const std::string& endpoint, const std::string& vrf); + void initialize(const std::string& endpoint, const std::string& vrf = ""); std::string m_endpoint; @@ -36,8 +43,10 @@ class ZmqClient bool m_connected; + uint32_t m_waitTimeMs; + std::mutex m_socketMutex; - + std::vector m_sendbuffer; }; diff --git a/common/zmqproducerstatetable.cpp b/common/zmqproducerstatetable.cpp index e2a31446b..6260f4767 100644 --- a/common/zmqproducerstatetable.cpp +++ b/common/zmqproducerstatetable.cpp @@ -1,18 +1,18 @@ -#include -#include -#include -#include +#include "zmqproducerstatetable.h" +#include "binaryserializer.h" +#include "redisapi.h" +#include "redispipeline.h" +#include "redisreply.h" +#include "table.h" +#include "zmqconsumerstatetable.h" #include #include #include +#include +#include +#include +#include #include -#include "redisreply.h" -#include "table.h" -#include "redisapi.h" -#include "redispipeline.h" -#include "zmqproducerstatetable.h" -#include "zmqconsumerstatetable.h" -#include "binaryserializer.h" using namespace std; @@ -164,6 +164,13 @@ void ZmqProducerStateTable::send(const std::vector &kcos } } +bool ZmqProducerStateTable::wait(const std::string& dbName, + const std::string& tableName, + const std::vector>& kcos) +{ + return m_zmqClient.wait(dbName, tableName, kcos); +} + size_t ZmqProducerStateTable::dbUpdaterQueueSize() { if (m_asyncDBUpdater == nullptr) diff --git a/common/zmqproducerstatetable.h b/common/zmqproducerstatetable.h index 015419bd2..09778d47a 100644 --- a/common/zmqproducerstatetable.h +++ b/common/zmqproducerstatetable.h @@ -37,6 +37,11 @@ class ZmqProducerStateTable : public ProducerStateTable // Batched send that can include both SET and DEL requests. virtual void send(const std::vector &kcos); + // To wait for the response from the peer. + virtual bool wait(const std::string& dbName, + const std::string& tableName, + const std::vector>& kcos); + size_t dbUpdaterQueueSize(); private: void initialize(DBConnector *db, const std::string &tableName, bool dbPersistence); diff --git a/common/zmqserver.cpp b/common/zmqserver.cpp index dca107405..a6383866a 100644 --- a/common/zmqserver.cpp +++ b/common/zmqserver.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -20,6 +21,7 @@ ZmqServer::ZmqServer(const std::string& endpoint, const std::string& vrf) : m_endpoint(endpoint), m_vrf(vrf) { + connect(); m_buffer.resize(MQ_RESPONSE_MAX_COUNT); m_runThread = true; m_mqPollThread = std::make_shared(&ZmqServer::mqPollThread, this); @@ -31,6 +33,33 @@ ZmqServer::~ZmqServer() { m_runThread = false; m_mqPollThread->join(); + + zmq_close(m_socket); + zmq_ctx_destroy(m_context); +} + +void ZmqServer::connect() +{ + SWSS_LOG_ENTER(); + m_context = zmq_ctx_new(); + m_socket = zmq_socket(m_context, ZMQ_PULL); + + // Increase recv buffer for use all bandwidth: http://api.zeromq.org/4-2:zmq-setsockopt + int high_watermark = MQ_WATERMARK; + zmq_setsockopt(m_socket, ZMQ_RCVHWM, &high_watermark, sizeof(high_watermark)); + + if (!m_vrf.empty()) + { + zmq_setsockopt(m_socket, ZMQ_BINDTODEVICE, m_vrf.c_str(), m_vrf.length()); + } + + int rc = zmq_bind(m_socket, m_endpoint.c_str()); + if (rc != 0) + { + SWSS_LOG_THROW("zmq_bind failed on endpoint: %s, zmqerrno: %d", + m_endpoint.c_str(), + zmq_errno()); + } } void ZmqServer::registerMessageHandler( @@ -90,32 +119,10 @@ void ZmqServer::mqPollThread() SWSS_LOG_ENTER(); SWSS_LOG_NOTICE("mqPollThread begin"); - // Producer/Consumer state table are n:1 mapping, so need use PUSH/PULL pattern http://api.zeromq.org/master:zmq-socket - void* context = zmq_ctx_new();; - void* socket = zmq_socket(context, ZMQ_PULL); - - // Increase recv buffer for use all bandwidth: http://api.zeromq.org/4-2:zmq-setsockopt - int high_watermark = MQ_WATERMARK; - zmq_setsockopt(socket, ZMQ_RCVHWM, &high_watermark, sizeof(high_watermark)); - - if (!m_vrf.empty()) - { - zmq_setsockopt(socket, ZMQ_BINDTODEVICE, m_vrf.c_str(), m_vrf.length()); - } - - int rc = zmq_bind(socket, m_endpoint.c_str()); - if (rc != 0) - { - SWSS_LOG_THROW("zmq_bind failed on endpoint: %s, zmqerrno: %d, message: %s", - m_endpoint.c_str(), - zmq_errno(), - strerror(zmq_errno())); - } - // zmq_poll will use less CPU zmq_pollitem_t poll_item; poll_item.fd = 0; - poll_item.socket = socket; + poll_item.socket = m_socket; poll_item.events = ZMQ_POLLIN; poll_item.revents = 0; @@ -123,7 +130,7 @@ void ZmqServer::mqPollThread() while (m_runThread) { // receive message - rc = zmq_poll(&poll_item, 1, 1000); + auto rc = zmq_poll(&poll_item, 1, 1000); if (rc == 0 || !(poll_item.revents & ZMQ_POLLIN)) { // timeout or other event @@ -132,7 +139,7 @@ void ZmqServer::mqPollThread() } // receive message - rc = zmq_recv(socket, m_buffer.data(), MQ_RESPONSE_MAX_COUNT, ZMQ_DONTWAIT); + rc = zmq_recv(m_socket, m_buffer.data(), MQ_RESPONSE_MAX_COUNT, ZMQ_DONTWAIT); if (rc < 0) { int zmq_err = zmq_errno(); @@ -160,11 +167,14 @@ void ZmqServer::mqPollThread() // deserialize and write to redis: handleReceivedData(m_buffer.data(), rc); } - - zmq_close(socket); - zmq_ctx_destroy(context); - SWSS_LOG_NOTICE("mqPollThread end"); } +// TODO: To be implemented later, required for ZMQ_CLIENT & ZMQ_SERVER +// socket types in response path. +void ZmqServer::sendMsg( + const std::string &dbName, const std::string &tableName, + const std::vector &values) { + return; +} } diff --git a/common/zmqserver.h b/common/zmqserver.h index 8afe18d7c..1b78b7a25 100644 --- a/common/zmqserver.h +++ b/common/zmqserver.h @@ -39,7 +39,13 @@ class ZmqServer const std::string tableName, ZmqMessageHandler* handler); + void sendMsg(const std::string& dbName, const std::string& tableName, + const std::vector& values); + private: + + void connect(); + void handleReceivedData(const char* buffer, const size_t size); void mqPollThread(); @@ -56,6 +62,10 @@ class ZmqServer std::string m_vrf; + void* m_context; + + void* m_socket; + std::map> m_HandlerMap; }; diff --git a/tests/c_api_ut.cpp b/tests/c_api_ut.cpp index 90fffdaa2..edee222fa 100644 --- a/tests/c_api_ut.cpp +++ b/tests/c_api_ut.cpp @@ -324,6 +324,7 @@ TEST(c_api, ZmqConsumerProducerStateTable) { SWSSZmqProducerStateTable_set(pst, arr.data[i].key, arr.data[i].fieldValues); else SWSSZmqClient_sendMsg(cli, "TEST_DB", "mytable", arr); + sleep(2); ASSERT_EQ(SWSSZmqConsumerStateTable_readData(cst, 1500, true), SWSSSelectResult_DATA); arr = SWSSZmqConsumerStateTable_pops(cst); @@ -362,6 +363,7 @@ TEST(c_api, ZmqConsumerProducerStateTable) { SWSSZmqProducerStateTable_del(pst, arr.data[i].key); else SWSSZmqClient_sendMsg(cli, "TEST_DB", "mytable", arr); + sleep(2); ASSERT_EQ(SWSSZmqConsumerStateTable_readData(cst, 500, true), SWSSSelectResult_DATA); arr = SWSSZmqConsumerStateTable_pops(cst); diff --git a/tests/zmq_state_ut.cpp b/tests/zmq_state_ut.cpp index 56a8299f9..2b0b60d73 100644 --- a/tests/zmq_state_ut.cpp +++ b/tests/zmq_state_ut.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include "gtest/gtest.h" #include "common/dbconnector.h" #include "common/notificationconsumer.h" @@ -14,6 +15,7 @@ #include "common/zmqclient.h" #include "common/zmqproducerstatetable.h" #include "common/zmqconsumerstatetable.h" +#include "common/binaryserializer.h" using namespace std; using namespace swss; @@ -257,6 +259,9 @@ static void consumerWorker(string tableName, string endpoint, bool dbPersistence } } + // Wait for some time to write into the DB. + sleep(3); + allDataReceived = true; if (dbPersistence) @@ -288,6 +293,9 @@ static void testMethod(bool producerPersistence) // start consumer first, SHM can only have 1 consumer per table. thread *consumerThread = new thread(consumerWorker, testTableName, pullEndpoint, !producerPersistence); + // Wait for the consumer to start. + sleep(1); + cout << "Starting " << NUMBER_OF_THREADS << " producers" << endl; /* Starting the producer before the producer */ for (int i = 0; i < NUMBER_OF_THREADS; i++) @@ -351,6 +359,9 @@ static void testBatchMethod(bool producerPersistence) // start consumer first, SHM can only have 1 consumer per table. thread *consumerThread = new thread(consumerWorker, testTableName, pullEndpoint, !producerPersistence); + // Wait for the consumer to start. + sleep(1); + cout << "Starting " << NUMBER_OF_THREADS << " producers" << endl; /* Starting the producer before the producer */ for (int i = 0; i < NUMBER_OF_THREADS; i++) @@ -465,3 +476,96 @@ TEST(ZmqProducerStateTableDeleteAfterSend, test) table.getKeys(keys); EXPECT_EQ(keys.front(), testKey); } + +static bool zmq_done = false; + +static void zmqConsumerWorker(string tableName, string endpoint, + bool dbPersistence) { + cout << "Consumer thread started: " << tableName << endl; + DBConnector db(TEST_DB, 0, true); + ZmqServer server(endpoint, ""); + ZmqConsumerStateTable c(&db, tableName, server, 128, 0, dbPersistence); + // validate received data + std::vector values; + values.push_back(KeyOpFieldsValuesTuple{ + "k", SET_COMMAND, + std::vector{FieldValueTuple{"f", "v"}}}); + + while (!zmq_done) { + sleep(2); + std::string recDbName, recTableName; + std::vector> recKcos; + std::vector deserializedKcos; + + BinarySerializer::deserializeBuffer(server.m_buffer.data(), + server.m_buffer.size(), recDbName, + recTableName, recKcos); + + for (auto kcoPtr : recKcos) + { + deserializedKcos.push_back(*kcoPtr); + } + EXPECT_EQ(recDbName, TEST_DB); + EXPECT_EQ(recTableName, tableName); + EXPECT_EQ(deserializedKcos, values); + } + + allDataReceived = true; + if (dbPersistence) + { + // wait all persist data write to redis + while (c.dbUpdaterQueueSize() > 0) + { + sleep(1); + } + } + + zmq_done = true; + cout << "Consumer thread ended: " << tableName << endl; +} + +static void ZmqWithResponse(bool producerPersistence) +{ + std::string testTableName = "ZMQ_PROD_CONS_UT"; + std::string pushEndpoint = "tcp://localhost:1234"; + std::string pullEndpoint = "tcp://*:1234"; + // start consumer first, SHM can only have 1 consumer per table. + thread *consumerThread = new thread(zmqConsumerWorker, testTableName, pullEndpoint, !producerPersistence); + + // Wait for the consumer to be ready. + sleep(1); + DBConnector db(TEST_DB, 0, true); + ZmqClient client(pushEndpoint, 3000); + ZmqProducerStateTable p(&db, testTableName, client, true); + std::vector kcos; + kcos.push_back(KeyOpFieldsValuesTuple{"k", SET_COMMAND, std::vector{FieldValueTuple{"f", "v"}}}); + for (int i = 0; i < 3; ++i) { + p.send(kcos); + } + + zmq_done = true; + consumerThread->join(); + delete consumerThread; +} + +TEST(ZmqWithResponse, test) +{ + // test with persist by consumer + ZmqWithResponse(false); +} + +TEST(ZmqWithResponseClientError, test) +{ + std::string testTableName = "ZMQ_PROD_CONS_UT"; + std::string pushEndpoint = "tcp://localhost:1234"; + DBConnector db(TEST_DB, 0, true); + ZmqClient client(pushEndpoint, 3000); + ZmqProducerStateTable p(&db, testTableName, client, true); + std::vector kcos; + kcos.push_back(KeyOpFieldsValuesTuple{"k", SET_COMMAND, std::vector{}}); + std::vector> kcosPtr; + std::string dbName, tableName; + p.send(kcos); + // Wait will timeout without server reply. + EXPECT_FALSE(p.wait(dbName, tableName, kcosPtr)); +}