Skip to content

Commit

Permalink
fix wss_boost bug
Browse files Browse the repository at this point in the history
Change-Id: Ic6b492a58730085e33a47d4db474e7c76849f07b
yuchuan.he committed Oct 25, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent cc54785 commit c609b40
Showing 4 changed files with 149 additions and 129 deletions.
193 changes: 114 additions & 79 deletions element/multimedia/encode/include/wss_boost.h
Original file line number Diff line number Diff line change
@@ -36,28 +36,80 @@ namespace net = boost::asio;
using tcp = net::ip::tcp;
using json = nlohmann::json;

class FlexibleBarrier {
public:
using Callback = std::function<void()>; // 定义回调类型

FlexibleBarrier(
int count, Callback callback = [] {})
: thread_count(count), count_to_wait(count), on_completion(callback) {}

void arrive_and_wait() {
std::unique_lock<std::mutex> lock(mtx);
--count_to_wait;

if (count_to_wait == 0) {
// 执行回调函数
on_completion();

count_to_wait = thread_count; // 重置等待计数
lock.unlock();
cv.notify_all(); // 唤醒所有等待线程
} else {
cv.wait(lock);
}
}

void add_thread() {
std::lock_guard<std::mutex> lock(mtx);
++thread_count;
++count_to_wait;
}
void del_thread() {
std::lock_guard<std::mutex> lock(mtx);
--thread_count;
--count_to_wait;
}
// 允许在运行时更改回调函数
void set_on_completion(Callback callback) {
std::lock_guard<std::mutex> lock(mtx);
on_completion = callback;
}
int get_thread_count() {
std::lock_guard<std::mutex> lock(mtx);
return thread_count;
}

private:
std::mutex mtx;
std::condition_variable cv;
int thread_count; // 总线程数
int count_to_wait; // 当前需要等待的线程数
Callback on_completion; // 完成时的回调函数
};

class WebSocketServer {
public:
WebSocketServer(unsigned short port, int fps, int conns)
WebSocketServer(unsigned short port, int fps)
: ioc_(),
acceptor_(ioc_, tcp::endpoint(tcp::v4(), port)),
fps_(fps),
conns_(conns),
strand_(net::make_strand(ioc_)),
timer_(ioc_) {
barrier_ = std::make_shared<FlexibleBarrier>(0, [this]() {
std::unique_lock<std::mutex> lock(mutex_);
message_queue_.pop();
});
}

void run();

bool is_open();

void destroy();

void reconnect(int index);
void reconnect();

void pushImgDataQueue(const std::string& data);

const int getConnectionsNum() const;
int getConnectionsNum();

private:
void do_accept();
@@ -68,115 +120,98 @@ class WebSocketServer {

const int MAX_WSS_QUEUE_LENGTH = 5;

class Session {
class Session : public std::enable_shared_from_this<Session> {
public:
Session(tcp::socket socket, std::mutex& mutex, std::condition_variable& cv)
: ws_(std::move(socket)), mutex_(mutex), cv_(cv) {
Session(tcp::socket socket, std::queue<std::string>& message_queue,
std::mutex& mutex, std::condition_variable& cv,
std::shared_ptr<FlexibleBarrier>& barrier)
: ws_(std::move(socket)),
message_queue_(message_queue),
mutex_(mutex),
cv_(cv),
barrier_(barrier) {
ws_.read_message_max(64 * 1024 * 1024); // 64 MB
// ws_.write_buffer_size(64 * 1024 * 1024); // 64 MB
}

Session(const Session&) = delete;
Session& operator=(const Session&) = delete;
~Session() {
while (!queue_.empty()) {
queue_.pop();
futureObj.wait();
if (writeThread_.joinable()) {
writeThread_.join();
}
std::cout << "Session released." << std::endl;
}

std::queue<std::string> queue_; // 要发送的消息队列
static bool shouldExit_;

void run() {

boost::system::error_code ec;
ws_.accept(ec);
if (!ec) {
writeThread_ = std::thread(&Session::do_write, this);
writeThread_.detach();
std::promise<void> promiseObj;
futureObj = promiseObj.get_future();
writeThread_ = std::thread(&Session::do_write, shared_from_this(),
std::move(promiseObj));
}
}

void stop() {
close();
shouldExit_ = true;
}

void setCallback(std::function<void(int)> cb) { callback = cb; }

void invokeCallback(int param) {
callback(param);
}

bool is_open() { return ws_.is_open(); }

void close() {
boost::system::error_code ec;
ws_.close(websocket::close_code::normal, ec);
}

bool should_des() {
if (rflag && wflag) return true;
return false;
}
void do_write(std::promise<void> promiseObj) {
barrier_->add_thread();
while (!shouldStop) {
std::string message;
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return !message_queue_.empty(); });

void do_read() {
while (is_open() && !shouldExit_) {
boost::system::error_code ec;
std::size_t bytes_transferred = ws_.read(buffer_, ec);
if (ec == boost::system::errc::connection_reset) {
if (is_open()) {
stop();
}
break;
} else if (ec) {
std::cerr << "Error reading data: " << ec.message() << std::endl;
if (is_open()) {
stop();
}
break;
message = message_queue_.front();
}
}
rflag = true;
}

void do_write() {
while (is_open() && !shouldExit_) {
std::unique_lock<std::mutex> lock(mutex_);
// cv_.wait(lock, [this] { return !queue_.empty(); });
cv_.wait(lock, [this] { return !queue_.empty();});
std::string message = queue_.front();
queue_.pop();
std::vector<uint8_t> binary_message(message.begin(), message.end());
ws_.binary(true);
try {
ws_.write(net::buffer(binary_message));
} catch (boost::wrapexcept<boost::system::system_error>& ex) {
if (is_open()) {
stop();
}
barrier_->arrive_and_wait();

ws_.text(true);


boost::system::error_code ec;
// Synchronously write the message to the WebSocket
ws_.write(net::buffer(message), ec);

// Check if the operation succeeded
if (!ec) {

} else {
Session::shouldExit_ = true;
barrier_->del_thread();
close();
break;
}
}
wflag = true;
promiseObj.set_value();
}

private:
/**
* @brief 从message_queue_中取出消息,异步发送。发送完成后递归调用自身
*
*/
bool rflag = false;
bool wflag = false;
bool shouldExit_ = false;
std::function<void(int)> callback; // 回调函数指针
std::thread readThread_, writeThread_;
websocket::stream<tcp::socket> ws_; // websocket会话的流对象
beast::flat_buffer buffer_; // 从客户端接收到的数据
std::thread writeThread_;
websocket::stream<tcp::socket> ws_; // websocket会话的流对象
beast::flat_buffer buffer_; // 从客户端接收到的数据
std::queue<std::string>& message_queue_; // 要发送的消息队列
std::mutex& mutex_;
std::atomic<bool> shouldStop;
std::condition_variable& cv_;
bool writing_ = false;
int index;
std::future<void> futureObj;
std::shared_ptr<FlexibleBarrier> barrier_; // std::thread writeThread_;
};
int num = 0;
int conns_;

net::io_context ioc_; // 管理IO上下文
tcp::acceptor acceptor_; // 侦听传入的连接请求,创建新的tcp::socket
int fps_;
@@ -187,8 +222,8 @@ class WebSocketServer {
std::mutex mutex_;
std::mutex ws_mutex;
std::condition_variable cv_;
// std::vector<std::shared_ptr<Session>> sessions_;
std::vector<Session*> sessions_;
std::vector<std::shared_ptr<Session>> sessions_;
std::shared_ptr<FlexibleBarrier> barrier_;
};

} // namespace encode
14 changes: 10 additions & 4 deletions element/multimedia/encode/src/encode.cc
Original file line number Diff line number Diff line change
@@ -407,7 +407,7 @@ void Encode::processWS(int dataPipeId,
mWSSMap[dataPipeId] = std::make_shared<WSSManager>(wss);
serverIt = mWSSMap.find(dataPipeId);
} else if (mWssBackend == WSSBackend::BOOST) {
auto wss = std::make_shared<WebSocketServer>(server_port, mFps, 4);
auto wss = std::make_shared<WebSocketServer>(server_port, mFps);
std::thread t([wss]() { wss->run(); });
std::lock_guard<std::mutex> lk(mWSSThreadsMutex);
mWSSThreads.push_back(std::move(t));
@@ -419,7 +419,7 @@ void Encode::processWS(int dataPipeId,
if (!serverIt->second->getConnectionsNum()) {
return;
}

std::string data;
if (mWsEncType == WSencType::IMG_ONLY) {
void* jpeg_data = NULL;
@@ -443,6 +443,7 @@ void Encode::processWS(int dataPipeId,
bm_image_create(objectMetadata->mFrame->mHandle, height, width,
FORMAT_YUV420P, image.data_type, &(*img_to_enc));
bmcv_rect_t crop_rect = {0, 0, img->width, img->height};
bm_image_alloc_dev_mem(*img_to_enc, 1);
bmcv_image_vpp_convert(objectMetadata->mFrame->mHandle, 1, *img,
img_to_enc.get(), &crop_rect);
} else {
@@ -451,8 +452,13 @@ void Encode::processWS(int dataPipeId,

bmcv_image_jpeg_enc(objectMetadata->mFrame->mHandle, 1, img_to_enc.get(),
&jpeg_data, &out_size);
data =
websocketpp::base64_encode((const unsigned char*)jpeg_data, out_size);
#if BASE64_CPU
// for cpu
data = common::base64_encode((const unsigned char*)jpeg_data, out_size);
#else
data = common::base64_encode_bmcv(objectMetadata->mFrame->mHandle,
(unsigned char*)jpeg_data, out_size);
#endif
free(jpeg_data);
}
if (mWsEncType == WSencType::SERIALIZED) {
46 changes: 11 additions & 35 deletions element/multimedia/encode/src/wss_boost.cc
Original file line number Diff line number Diff line change
@@ -3,8 +3,7 @@
namespace sophon_stream {
namespace element {
namespace encode {

int idx = 0;
bool WebSocketServer::Session::shouldExit_ = false;

void WebSocketServer::run() {
do_accept();
@@ -14,53 +13,30 @@ void WebSocketServer::run() {
void WebSocketServer::pushImgDataQueue(const std::string& data) {
std::lock_guard<std::mutex> lock(mutex_);
if (message_queue_.size() < MAX_WSS_QUEUE_LENGTH) {
for (auto& session_ : sessions_) {
if (session_->is_open()) {
session_->queue_.push(data);
}
}
message_queue_.push(data);

cv_.notify_all();
}
}



void WebSocketServer::do_accept() {
int num = 0;
while (1) {
tcp::socket socket(ioc_);

acceptor_.accept(socket);
if (sessions_.size() == 4) {
cv_.notify_all();
for (int i = 0; i < 4; ++i) {
delete sessions_[i];
sessions_[i] = nullptr;
}
sessions_.clear();
}

auto session_ = new Session(std::move(socket), mutex_, cv_);
sessions_.push_back(session_);
sessions_.back()->run();
std::cout << sessions_.size() << std::endl;
auto session_ = std::make_shared<Session>(std::move(socket), message_queue_,
mutex_, cv_, barrier_);
session_->run();
}
}

const int WebSocketServer::getConnectionsNum() const {
return sessions_.size();
int WebSocketServer::getConnectionsNum() {
return barrier_->get_thread_count();
}

void WebSocketServer::reconnect(int index) {
// std::cout << sessions_.back().use_count()<<std::endl;
// for(int i = 0; i < 2; ++i) {
// if(sessions_[i]->is_open())
// sessions_[i]->close();
// delete sessions_[i];
// sessions_[i] = nullptr;
// }
// ioc_.restart();
std::cout << index << std::endl;
if (sessions_[index]->is_open()) sessions_[index]->close();
sessions_[index] = nullptr;
}
} // namespace encode
} // namespace element
} // namespace sophon_stream
25 changes: 14 additions & 11 deletions framework/common/serialize.h
Original file line number Diff line number Diff line change
@@ -128,7 +128,17 @@ std::string base64_encode(unsigned char const* input, size_t len) {

return ret;
}

std::string base64_encode_bmcv(bm_handle_t handle_, unsigned char* jpegData,
size_t nBytes) {
// for bmcv
unsigned long origin_len[2] = {nBytes, 0};
unsigned long encode_len[2] = {(origin_len[0] + 2) / 3 * 4, 0};
std::string res(encode_len[0], '\0');
bmcv_base64_enc(handle_, bm_mem_from_system(jpegData),
bm_mem_from_system(const_cast<char*>(res.c_str())),
origin_len);
return res;
}
std::string frame_to_base64(Frame& frame) {
#if ENABLE_TIME_LOG
timeval time1, time2, time3, time4, time5;
@@ -137,9 +147,9 @@ std::string frame_to_base64(Frame& frame) {
unsigned char* jpegData = nullptr;
size_t nBytes = 0;
bm_image bgr_;
if(frame.mSpDataOsd!=nullptr){
if (frame.mSpDataOsd != nullptr) {
bgr_ = *(frame.mSpDataOsd);
}else{
} else {
bgr_ = *(frame.mSpData);
}
bm_handle_t handle_ = bm_image_get_handle(&bgr_);
@@ -163,18 +173,11 @@ std::string frame_to_base64(Frame& frame) {
gettimeofday(&time4, NULL);
#endif
bm_image_destroy(yuv_);

#if BASE64_CPU
// for cpu
std::string res = base64_encode(jpegData, nBytes);
#else
// for bmcv
unsigned long origin_len[2] = {nBytes, 0};
unsigned long encode_len[2] = {(origin_len[0] + 2) / 3 * 4, 0};
std::string res(encode_len[0], '\0');
bmcv_base64_enc(handle_, bm_mem_from_system(jpegData),
bm_mem_from_system(const_cast<char*>(res.c_str())),
origin_len);
std::string res = base64_encode_bmcv(handle_, jpegData, nBytes);
#endif

#if ENABLE_TIME_LOG

0 comments on commit c609b40

Please sign in to comment.