diff --git a/common/asyncdbupdater.h b/common/asyncdbupdater.h index 4826661a..25c6ef70 100644 --- a/common/asyncdbupdater.h +++ b/common/asyncdbupdater.h @@ -6,7 +6,6 @@ #include "dbconnector.h" #include "table.h" -#define MQ_RESPONSE_MAX_COUNT (4*1024*1024) #define MQ_SIZE 100 #define MQ_MAX_RETRY 10 #define MQ_POLL_TIMEOUT (1000) diff --git a/common/binaryserializer.h b/common/binaryserializer.h index 18fc3472..413ca501 100644 --- a/common/binaryserializer.h +++ b/common/binaryserializer.h @@ -3,6 +3,8 @@ #include "common/armhelper.h" +#include + using namespace std; namespace swss { @@ -12,27 +14,35 @@ class BinarySerializer { static size_t serializeBuffer( const char* buffer, const size_t size, - const std::string& key, - const std::vector& values, - const std::string& command, const std::string& dbName, - const std::string& tableName) + const std::string& tableName, + const std::vector& kcos) { auto tmpSerializer = BinarySerializer(buffer, size); + // Set the first pair as DB name and table name. tmpSerializer.setKeyAndValue( dbName.c_str(), dbName.length(), tableName.c_str(), tableName.length()); - tmpSerializer.setKeyAndValue( - key.c_str(), key.length(), - command.c_str(), command.length()); - for (auto& kvp : values) + for (auto& kco : kcos) { - auto& field = fvField(kvp); - auto& value = fvValue(kvp); + auto& key = kfvKey(kco); + auto& fvs = kfvFieldsValues(kco); + std::string fvs_len = std::to_string(fvs.size()); + // For each request, the first pair is the key and the number of attributes, + // followed by the attribute pairs. + // The operation is not set, when there is no attribute, it is a DEL request. tmpSerializer.setKeyAndValue( - field.c_str(), field.length(), - value.c_str(), value.length()); + key.c_str(), key.length(), + fvs_len.c_str(), fvs_len.length()); + for (auto& fv : fvs) + { + auto& field = fvField(fv); + auto& value = fvValue(fv); + tmpSerializer.setKeyAndValue( + field.c_str(), field.length(), + value.c_str(), value.length()); + } } return tmpSerializer.finalize(); @@ -88,6 +98,56 @@ class BinarySerializer { } } + static void deserializeBuffer( + const char* buffer, + const size_t size, + std::string& dbName, + std::string& tableName, + std::vector>& kcos) + { + std::vector values; + deserializeBuffer(buffer, size, values); + int fvs_size = -1; + KeyOpFieldsValuesTuple kco; + auto& key = kfvKey(kco); + auto& op = kfvOp(kco); + auto& fvs = kfvFieldsValues(kco); + for (auto& fv : values) + { + auto& field = fvField(fv); + auto& value = fvValue(fv); + // The first pair is the DB name and the table name. + if (fvs_size < 0) + { + dbName = field; + tableName = value; + fvs_size = 0; + continue; + } + // This is the beginning of a request. + // The first pair is the key and the number of attributes. + // If the attribute count is zero, it is a DEL request. + if (fvs_size == 0) + { + key = field; + fvs_size = std::stoi(value); + op = (fvs_size == 0) ? DEL_COMMAND : SET_COMMAND; + fvs.clear(); + } + // This is an attribut pair. + else + { + fvs.push_back(fv); + --fvs_size; + } + // We got the last attribut pair. This is the end of a request. + if (fvs_size == 0) + { + kcos.push_back(std::make_shared(kco)); + } + } + } + private: const char* m_buffer; const size_t m_buffer_size; diff --git a/common/zmqclient.cpp b/common/zmqclient.cpp index dc7f8d07..e6cb07da 100644 --- a/common/zmqclient.cpp +++ b/common/zmqclient.cpp @@ -103,21 +103,24 @@ void ZmqClient::connect() } void ZmqClient::sendMsg( - const std::string& key, - const std::vector& values, - const std::string& command, const std::string& dbName, const std::string& tableName, + const std::vector& kcos, std::vector& sendbuffer) { int serializedlen = (int)BinarySerializer::serializeBuffer( sendbuffer.data(), sendbuffer.size(), - key, - values, - command, dbName, - tableName); + tableName, + kcos); + + if (serializedlen >= MQ_RESPONSE_MAX_COUNT) + { + SWSS_LOG_THROW("ZmqClient sendMsg message was too big (buffer size %d bytes, got %d), reduce the message size, message DROPPED", + MQ_RESPONSE_MAX_COUNT, + serializedlen); + } SWSS_LOG_DEBUG("sending: %d", serializedlen); int zmq_err = 0; diff --git a/common/zmqclient.h b/common/zmqclient.h index efe33bd5..3f56cc29 100644 --- a/common/zmqclient.h +++ b/common/zmqclient.h @@ -19,11 +19,9 @@ class ZmqClient void connect(); - void sendMsg(const std::string& key, - const std::vector& values, - const std::string& command, - const std::string& dbName, + void sendMsg(const std::string& dbName, const std::string& tableName, + const std::vector& kcos, std::vector& sendbuffer); private: void initialize(const std::string& endpoint); diff --git a/common/zmqconsumerstatetable.cpp b/common/zmqconsumerstatetable.cpp index 5795f1fa..5f58482f 100644 --- a/common/zmqconsumerstatetable.cpp +++ b/common/zmqconsumerstatetable.cpp @@ -39,26 +39,28 @@ ZmqConsumerStateTable::ZmqConsumerStateTable(DBConnector *db, const std::string SWSS_LOG_DEBUG("ZmqConsumerStateTable ctor tableName: %s", tableName.c_str()); } -void ZmqConsumerStateTable::handleReceivedData(std::shared_ptr pkco) +void ZmqConsumerStateTable::handleReceivedData(const std::vector> &kcos) { - std::shared_ptr clone = nullptr; - if (m_asyncDBUpdater != nullptr) + for (auto kco : kcos) { - // clone before put to received queue, because received data may change by consumer. - clone = std::make_shared(*pkco); - } - - { - std::lock_guard lock(m_receivedQueueMutex); - m_receivedOperationQueue.push(pkco); - } + std::shared_ptr clone = nullptr; + if (m_asyncDBUpdater != nullptr) + { + // clone before put to received queue, because received data may change by consumer. + clone = std::make_shared(*kco); + } - m_selectableEvent.notify(); // will release epoll + { + std::lock_guard lock(m_receivedQueueMutex); + m_receivedOperationQueue.push(kco); + } - if (m_asyncDBUpdater != nullptr) - { - m_asyncDBUpdater->update(clone); + if (m_asyncDBUpdater != nullptr) + { + m_asyncDBUpdater->update(clone); + } } + m_selectableEvent.notify(); // will release epoll } /* Get multiple pop elements */ diff --git a/common/zmqconsumerstatetable.h b/common/zmqconsumerstatetable.h index d18e5dc2..dece60bd 100644 --- a/common/zmqconsumerstatetable.h +++ b/common/zmqconsumerstatetable.h @@ -70,7 +70,7 @@ class ZmqConsumerStateTable : public Selectable, public TableBase, public ZmqMes size_t dbUpdaterQueueSize(); private: - void handleReceivedData(std::shared_ptr pkco); + void handleReceivedData(const std::vector> &kcos); std::mutex m_receivedQueueMutex; diff --git a/common/zmqproducerstatetable.cpp b/common/zmqproducerstatetable.cpp index b2afbb75..ec9396b3 100644 --- a/common/zmqproducerstatetable.cpp +++ b/common/zmqproducerstatetable.cpp @@ -58,12 +58,13 @@ void ZmqProducerStateTable::set( const string &op /*= SET_COMMAND*/, const string &prefix) { + std::vector kcos = std::vector{ + KeyOpFieldsValuesTuple{key, op, values} + }; m_zmqClient.sendMsg( - key, - values, - op, m_dbName, m_tableNameStr, + kcos, m_sendbuffer); if (m_asyncDBUpdater != nullptr) @@ -86,12 +87,13 @@ void ZmqProducerStateTable::del( const string &op /*= DEL_COMMAND*/, const string &prefix) { + std::vector kcos = std::vector{ + KeyOpFieldsValuesTuple{key, op, std::vector{}} + }; m_zmqClient.sendMsg( - key, - vector(), - op, m_dbName, m_tableNameStr, + kcos, m_sendbuffer); if (m_asyncDBUpdater != nullptr) @@ -107,16 +109,11 @@ void ZmqProducerStateTable::del( void ZmqProducerStateTable::set(const std::vector &values) { - for (const auto &value : values) - { - m_zmqClient.sendMsg( - kfvKey(value), - kfvFieldsValues(value), - SET_COMMAND, - m_dbName, - m_tableNameStr, - m_sendbuffer); - } + m_zmqClient.sendMsg( + m_dbName, + m_tableNameStr, + values, + m_sendbuffer); if (m_asyncDBUpdater != nullptr) { @@ -131,16 +128,16 @@ void ZmqProducerStateTable::set(const std::vector &value void ZmqProducerStateTable::del(const std::vector &keys) { + std::vector kcos; for (const auto &key : keys) { - m_zmqClient.sendMsg( - key, - vector(), - DEL_COMMAND, - m_dbName, - m_tableNameStr, - m_sendbuffer); + kcos.push_back(KeyOpFieldsValuesTuple{key, DEL_COMMAND, std::vector{}}); } + m_zmqClient.sendMsg( + m_dbName, + m_tableNameStr, + kcos, + m_sendbuffer); if (m_asyncDBUpdater != nullptr) { @@ -155,6 +152,25 @@ void ZmqProducerStateTable::del(const std::vector &keys) } } +void ZmqProducerStateTable::send(const std::vector &kcos) +{ + m_zmqClient.sendMsg( + m_dbName, + m_tableNameStr, + kcos, + m_sendbuffer); + + if (m_asyncDBUpdater != nullptr) + { + for (const auto &value : kcos) + { + // async write need keep data till write to DB + std::shared_ptr clone = std::make_shared(value); + m_asyncDBUpdater->update(clone); + } + } +} + size_t ZmqProducerStateTable::dbUpdaterQueueSize() { if (m_asyncDBUpdater == nullptr) diff --git a/common/zmqproducerstatetable.h b/common/zmqproducerstatetable.h index 8c784d42..74910782 100644 --- a/common/zmqproducerstatetable.h +++ b/common/zmqproducerstatetable.h @@ -34,6 +34,9 @@ class ZmqProducerStateTable : public ProducerStateTable virtual void del(const std::vector &keys); + // Batched send that can include both SET and DEL requests. + virtual void send(const std::vector &kcos); + size_t dbUpdaterQueueSize(); private: void initialize(DBConnector *db, const std::string &tableName, bool dbPersistence); diff --git a/common/zmqserver.cpp b/common/zmqserver.cpp index 2a3e3f2f..118f1609 100644 --- a/common/zmqserver.cpp +++ b/common/zmqserver.cpp @@ -14,6 +14,7 @@ namespace swss { ZmqServer::ZmqServer(const std::string& endpoint) : m_endpoint(endpoint) { + m_buffer.resize(MQ_RESPONSE_MAX_COUNT); m_mqPollThread = std::make_shared(&ZmqServer::mqPollThread, this); m_runThread = true; @@ -63,16 +64,10 @@ ZmqMessageHandler* ZmqServer::findMessageHandler( void ZmqServer::handleReceivedData(const char* buffer, const size_t size) { - auto pkco = std::make_shared(); - KeyOpFieldsValuesTuple &kco = *pkco; - auto& values = kfvFieldsValues(kco); - BinarySerializer::deserializeBuffer(buffer, size, values); - - // get table name - swss::FieldValueTuple fvt = values.at(0); - string dbName = fvField(fvt); - string tableName = fvValue(fvt); - values.erase(values.begin()); + std::string dbName; + std::string tableName; + std::vector> kcos; + BinarySerializer::deserializeBuffer(buffer, size, dbName, tableName, kcos); // find handler auto handler = findMessageHandler(dbName, tableName); @@ -81,21 +76,13 @@ void ZmqServer::handleReceivedData(const char* buffer, const size_t size) return; } - // get key and OP - fvt = values.at(0); - kfvKey(kco) = fvField(fvt); - kfvOp(kco) = fvValue(fvt); - values.erase(values.begin()); - - handler->handleReceivedData(pkco); + handler->handleReceivedData(kcos); } void ZmqServer::mqPollThread() { SWSS_LOG_ENTER(); SWSS_LOG_NOTICE("mqPollThread begin"); - std::vector buffer; - buffer.resize(MQ_RESPONSE_MAX_COUNT); // Producer/Consumer state table are n:1 mapping, so need use PUSH/PULL pattern http://api.zeromq.org/master:zmq-socket void* context = zmq_ctx_new();; @@ -133,7 +120,7 @@ void ZmqServer::mqPollThread() } // receive message - rc = zmq_recv(socket, buffer.data(), MQ_RESPONSE_MAX_COUNT, ZMQ_DONTWAIT); + rc = zmq_recv(socket, m_buffer.data(), MQ_RESPONSE_MAX_COUNT, ZMQ_DONTWAIT); if (rc < 0) { int zmq_err = zmq_errno(); @@ -155,11 +142,11 @@ void ZmqServer::mqPollThread() rc); } - buffer.at(rc) = 0; // make sure that we end string with zero before parse + m_buffer.at(rc) = 0; // make sure that we end string with zero before parse SWSS_LOG_DEBUG("zmq received %d bytes", rc); // deserialize and write to redis: - handleReceivedData(buffer.data(), rc); + handleReceivedData(m_buffer.data(), rc); } zmq_close(socket); diff --git a/common/zmqserver.h b/common/zmqserver.h index 596ab2c2..002e78b1 100644 --- a/common/zmqserver.h +++ b/common/zmqserver.h @@ -3,9 +3,10 @@ #include #include #include +#include #include "table.h" -#define MQ_RESPONSE_MAX_COUNT (4*1024*1024) +#define MQ_RESPONSE_MAX_COUNT (16*1024*1024) #define MQ_SIZE 100 #define MQ_MAX_RETRY 10 #define MQ_POLL_TIMEOUT (1000) @@ -20,7 +21,7 @@ class ZmqMessageHandler { public: virtual ~ZmqMessageHandler() {}; - virtual void handleReceivedData(std::shared_ptr pkco) = 0; + virtual void handleReceivedData(const std::vector>& kcos) = 0; }; class ZmqServer @@ -44,6 +45,8 @@ class ZmqServer ZmqMessageHandler* findMessageHandler(const std::string dbName, const std::string tableName); + std::vector m_buffer; + volatile bool m_runThread; std::shared_ptr m_mqPollThread; diff --git a/tests/binary_serializer_ut.cpp b/tests/binary_serializer_ut.cpp index 19219cb8..e6f7392b 100644 --- a/tests/binary_serializer_ut.cpp +++ b/tests/binary_serializer_ut.cpp @@ -10,43 +10,43 @@ using namespace swss; TEST(BinarySerializer, serialize_deserialize) { string test_entry_key = "test_key"; - string test_command = "test_command"; + string test_command = "SET"; string test_db = "test_db"; string test_table = "test_table"; string test_key = "key"; string test_value= "value"; + string test_entry_key2 = "test_key_2"; + string test_command2 = "DEL"; char buffer[200]; std::vector values; values.push_back(std::make_pair(test_key, test_value)); + std::vector kcos = std::vector{ + KeyOpFieldsValuesTuple{test_entry_key, test_command, values}, + KeyOpFieldsValuesTuple{test_entry_key2, test_command2, std::vector{}}}; int serialized_len = (int)BinarySerializer::serializeBuffer( buffer, sizeof(buffer), - test_entry_key, - values, - test_command, test_db, - test_table); + test_table, + kcos); string serialized_str(buffer); - EXPECT_EQ(serialized_len, 101); - - auto ptr = std::make_shared(); - KeyOpFieldsValuesTuple &kco = *ptr; - auto& deserialized_values = kfvFieldsValues(kco); - BinarySerializer::deserializeBuffer(buffer, serialized_len, deserialized_values); - - swss::FieldValueTuple fvt = deserialized_values.at(0); - EXPECT_TRUE(fvField(fvt) == test_db); - EXPECT_TRUE(fvValue(fvt) == test_table); - - fvt = deserialized_values.at(1); - EXPECT_TRUE(fvField(fvt) == test_entry_key); - EXPECT_TRUE(fvValue(fvt) == test_command); - - fvt = deserialized_values.at(2); - EXPECT_TRUE(fvField(fvt) == test_key); - EXPECT_TRUE(fvValue(fvt) == test_value); + EXPECT_EQ(serialized_len, 117); + + std::vector> kcos_ptrs; + std::vector deserialized_kcos; + string db_name; + string db_table; + BinarySerializer::deserializeBuffer(buffer, serialized_len, db_name, db_table, kcos_ptrs); + for (auto kco_ptr : kcos_ptrs) + { + deserialized_kcos.push_back(*kco_ptr); + } + + EXPECT_EQ(db_name, test_db); + EXPECT_EQ(db_table, test_table); + EXPECT_EQ(deserialized_kcos, kcos); } TEST(BinarySerializer, serialize_overflow) @@ -54,14 +54,14 @@ TEST(BinarySerializer, serialize_overflow) char buffer[50]; std::vector values; values.push_back(std::make_pair("test_key", "test_value")); + std::vector kcos = std::vector{ + KeyOpFieldsValuesTuple{"test_entry_key", "SET", values}}; EXPECT_THROW(BinarySerializer::serializeBuffer( buffer, sizeof(buffer), - "test_entry_key", - values, - "test_command", "test_db", - "test_table"), runtime_error); + "test_table", + kcos), runtime_error); } TEST(BinarySerializer, deserialize_overflow) @@ -69,26 +69,26 @@ TEST(BinarySerializer, deserialize_overflow) char buffer[200]; std::vector values; values.push_back(std::make_pair("test_key", "test_value")); + std::vector kcos = std::vector{ + KeyOpFieldsValuesTuple{"test_entry_key", "SET", values}}; int serialized_len = (int)BinarySerializer::serializeBuffer( buffer, sizeof(buffer), - "test_entry_key", - values, - "test_command", "test_db", - "test_table"); + "test_table", + kcos); string serialized_str(buffer); - auto ptr = std::make_shared(); - KeyOpFieldsValuesTuple &kco = *ptr; - auto& deserialized_values = kfvFieldsValues(kco); - EXPECT_THROW(BinarySerializer::deserializeBuffer(buffer, serialized_len - 10, deserialized_values), runtime_error); + std::vector> kcos_ptrs; + string db_name; + string db_table; + EXPECT_THROW(BinarySerializer::deserializeBuffer(buffer, serialized_len - 10, db_name, db_table, kcos_ptrs), runtime_error); } TEST(BinarySerializer, protocol_buffer) { string test_entry_key = "test_key"; - string test_command = "test_command"; + string test_command = "SET"; string test_db = "test_db"; string test_table = "test_table"; string test_key = "key"; @@ -99,33 +99,29 @@ TEST(BinarySerializer, protocol_buffer) char buffer[200]; std::vector values; values.push_back(std::make_pair(test_key, proto_buf_val)); + std::vector kcos = std::vector{ + KeyOpFieldsValuesTuple{test_entry_key, test_command, values}}; int serialized_len = (int)BinarySerializer::serializeBuffer( buffer, sizeof(buffer), - test_entry_key, - values, - test_command, test_db, - test_table); + test_table, + kcos); string serialized_str(buffer); - EXPECT_EQ(serialized_len, 106); - - auto ptr = std::make_shared(); - KeyOpFieldsValuesTuple &kco = *ptr; - auto& deserialized_values = kfvFieldsValues(kco); - BinarySerializer::deserializeBuffer(buffer, serialized_len, deserialized_values); - - swss::FieldValueTuple fvt = deserialized_values.at(0); - EXPECT_TRUE(fvField(fvt) == test_db); - EXPECT_TRUE(fvValue(fvt) == test_table); - - fvt = deserialized_values.at(1); - EXPECT_TRUE(fvField(fvt) == test_entry_key); - EXPECT_TRUE(fvValue(fvt) == test_command); - - fvt = deserialized_values.at(2); - EXPECT_TRUE(fvField(fvt) == test_key); - EXPECT_TRUE(fvValue(fvt) == proto_buf_val); - EXPECT_TRUE(fvValue(fvt).length() == sizeof(binary_proto_buf)); + EXPECT_EQ(serialized_len, 95); + + std::vector> kcos_ptrs; + std::vector deserialized_kcos; + string db_name; + string db_table; + BinarySerializer::deserializeBuffer(buffer, serialized_len, db_name, db_table, kcos_ptrs); + for (auto kco_ptr : kcos_ptrs) + { + deserialized_kcos.push_back(*kco_ptr); + } + + EXPECT_EQ(db_name, test_db); + EXPECT_EQ(db_table, test_table); + EXPECT_EQ(deserialized_kcos, kcos); } diff --git a/tests/zmq_state_ut.cpp b/tests/zmq_state_ut.cpp index 1e9d2254..4818b7fd 100644 --- a/tests/zmq_state_ut.cpp +++ b/tests/zmq_state_ut.cpp @@ -127,6 +127,74 @@ static void producerWorker(string tableName, string endpoint, bool dbPersistence cout << "Producer thread ended: " << tableName << endl; } +// Reusing the same keys as the producerWorker so that the same consumer thread +// can be used. +static void producerBatchWorker(string tableName, string endpoint, bool dbPersistence) +{ + DBConnector db(TEST_DB, 0, true); + ZmqClient client(endpoint); + ZmqProducerStateTable p(&db, tableName, client, dbPersistence); + cout << "Producer thread started: " << tableName << endl; + std::vector kcos; + + for (int i = 0; i < NUMBER_OF_OPS; i++) + { + vector fields; + for (int j = 0; j < MAX_FIELDS; j++) + { + FieldValueTuple t(field(j), value(j)); + fields.push_back(t); + } + kcos.push_back(KeyOpFieldsValuesTuple("set_key_" + to_string(i), SET_COMMAND, fields)); + } + + for (int i = 0; i < NUMBER_OF_OPS; i++) + { + kcos.push_back(KeyOpFieldsValuesTuple("del_key_" + to_string(i), DEL_COMMAND, vector{})); + } + + for (int i = 0; i < NUMBER_OF_OPS; i++) + { + for (int j = 0; j < MAX_KEYS; j++) + { + vector fields; + for (int k = 0; k < MAX_FIELDS; k++) + { + FieldValueTuple t(field(k), value(k)); + fields.push_back(t); + } + kcos.push_back(KeyOpFieldsValuesTuple("batch_set_key_" + to_string(i) + "_" + to_string(j), SET_COMMAND, fields)); + } + } + + for (int i = 0; i < NUMBER_OF_OPS; i++) + { + for (int j = 0; j < MAX_KEYS; j++) + { + kcos.push_back(KeyOpFieldsValuesTuple("batch_del_key_" + to_string(i) + "_" + to_string(j), DEL_COMMAND, vector{})); + } + } + + p.send(kcos); + + // wait all data been received by consumer + while (!allDataReceived) + { + sleep(1); + } + + if (dbPersistence) + { + // wait all persist data write to redis + while (p.dbUpdaterQueueSize() > 0) + { + sleep(1); + } + } + + cout << "Producer thread ended: " << tableName << endl; +} + // variable used by consumer worker static int setCount = 0; static int delCount = 0; @@ -203,7 +271,6 @@ static void consumerWorker(string tableName, string endpoint, bool dbPersistence cout << "Consumer thread ended: " << tableName << endl; } - static void testMethod(bool producerPersistence) { std::string testTableName = "ZMQ_PROD_CONS_UT"; @@ -243,6 +310,69 @@ static void testMethod(bool producerPersistence) EXPECT_EQ(batchSetCount, NUMBER_OF_THREADS * NUMBER_OF_OPS * MAX_KEYS); EXPECT_EQ(batchDelCount, NUMBER_OF_THREADS * NUMBER_OF_OPS * MAX_KEYS); + // check presist data in redis + DBConnector db(TEST_DB, 0, true); + Table table(&db, testTableName); + std::vector keys; + table.getKeys(keys); + setCount = 0; + batchSetCount = 0; + for (string& key : keys) + { + if (key.rfind("batch_set_key_", 0) == 0) + { + batchSetCount++; + } + else if (key.rfind("set_key_", 0) == 0) + { + setCount++; + } + } + EXPECT_EQ(setCount, NUMBER_OF_OPS); + EXPECT_EQ(batchSetCount, NUMBER_OF_OPS * MAX_KEYS); + + cout << endl << "Done." << endl; +} + +static void testBatchMethod(bool producerPersistence) +{ + std::string testTableName = "ZMQ_PROD_CONS_UT"; + std::string pushEndpoint = "tcp://localhost:1234"; + std::string pullEndpoint = "tcp://*:1234"; + thread *producerThreads[NUMBER_OF_THREADS]; + + // reset receive data counter + setCount = 0; + delCount = 0; + batchSetCount = 0; + batchDelCount = 0; + allDataReceived = false; + + // start consumer first, SHM can only have 1 consumer per table. + thread *consumerThread = new thread(consumerWorker, testTableName, pullEndpoint, !producerPersistence); + + cout << "Starting " << NUMBER_OF_THREADS << " producers" << endl; + /* Starting the producer before the producer */ + for (int i = 0; i < NUMBER_OF_THREADS; i++) + { + producerThreads[i] = new thread(producerBatchWorker, testTableName, pushEndpoint, producerPersistence); + } + + cout << "Done. Waiting for all job to finish " << NUMBER_OF_OPS << " jobs." << endl; + for (int i = 0; i < NUMBER_OF_THREADS; i++) + { + producerThreads[i]->join(); + delete producerThreads[i]; + } + + consumerThread->join(); + delete consumerThread; + + EXPECT_EQ(setCount, NUMBER_OF_THREADS * NUMBER_OF_OPS); + EXPECT_EQ(delCount, NUMBER_OF_THREADS * NUMBER_OF_OPS); + EXPECT_EQ(batchSetCount, NUMBER_OF_THREADS * NUMBER_OF_OPS * MAX_KEYS); + EXPECT_EQ(batchDelCount, NUMBER_OF_THREADS * NUMBER_OF_OPS * MAX_KEYS); + // check presist data in redis DBConnector db(TEST_DB, 0, true); Table table(&db, testTableName); @@ -273,9 +403,38 @@ TEST(ZmqConsumerStateTable, test) testMethod(false); } - TEST(ZmqProducerStateTable, test) { // test with persist by producer testMethod(true); -} \ No newline at end of file +} + +TEST(ZmqConsumerStateTableBatch, test) +{ + // test with persist by consumer + testBatchMethod(false); +} + +TEST(ZmqProducerStateTableBatch, test) +{ + // test with persist by producer + testBatchMethod(true); +} + +TEST(ZmqConsumerStateTableBatchBufferOverflow, test) +{ + std::string testTableName = "ZMQ_PROD_CONS_UT"; + std::string pushEndpoint = "tcp://localhost:1234"; + + DBConnector db(TEST_DB, 0, true); + ZmqClient client(pushEndpoint); + ZmqProducerStateTable p(&db, testTableName, client, true); + + // Send a large message and expect exception thrown. + std::vector kcos; + for (int i = 0; i <= MQ_RESPONSE_MAX_COUNT; i++) + { + kcos.push_back(KeyOpFieldsValuesTuple("key", DEL_COMMAND, vector{})); + } + EXPECT_ANY_THROW(p.send(kcos)); +}