[feature] Graceful termination of background threads in LlamaV2 #458

Merged: 3 commits, Sep 26, 2023

15 changes: 13 additions & 2 deletions src/turbomind/models/llama/LlamaV2.cc
@@ -126,6 +126,7 @@ LlamaV2<T>::LlamaV2(size_t head_num,
template<typename T>
LlamaV2<T>::~LlamaV2()
{
shared_state_->request_queue.close();
internal_thread_.join();

delete decoder_;
@@ -448,12 +449,24 @@ void LlamaV2<T>::internalThreadEntry(int device_id)

request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty);

// request queue was closed
// and there are no unprocessed requests in the queue
if (is_empty && infer_requests.empty() && stop_requests.empty()) {
// rank 0 sets flag
shared_state_->should_stop = true;
}

batch_.verifyRequests(stop_requests, infer_requests);
}

// wait while rank-0 is dequeueing
shared_state_->barrier->wait();

// exit if job is done
if (shared_state_->should_stop) {
return;
}

bool modified = false;

if (!(batch_.finishedCount() == 0 && stop_requests.empty() && infer_requests.empty())) {
@@ -486,8 +499,6 @@ void LlamaV2<T>::internalThreadEntry(int device_id)
batch_.finish();
}
}

- FT_CHECK(0);
}

template<typename T>
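
For reference, a minimal self-contained sketch of the shutdown handshake these hunks implement: rank 0 notices that the request queue is closed and drained, sets a shared should_stop flag, every rank re-checks the flag after a barrier, and the loop returns instead of asserting via FT_CHECK(0). Barrier, SharedState, rankLoop and queue_closed below are illustrative stand-ins, not the actual TurboMind types:

    #include <atomic>
    #include <chrono>
    #include <condition_variable>
    #include <iostream>
    #include <mutex>
    #include <thread>
    #include <vector>

    // Simple reusable barrier, standing in for turbomind::Barrier.
    class Barrier {
    public:
        explicit Barrier(int count): threshold_(count), count_(count) {}

        void wait()
        {
            std::unique_lock<std::mutex> lock(mutex_);
            const int gen = generation_;
            if (--count_ == 0) {
                ++generation_;
                count_ = threshold_;
                cv_.notify_all();
            }
            else {
                cv_.wait(lock, [&] { return gen != generation_; });
            }
        }

    private:
        std::mutex              mutex_;
        std::condition_variable cv_;
        const int               threshold_;
        int                     count_;
        int                     generation_ = 0;
    };

    // Shared between all ranks, cf. LlamaV2<T>::SharedState.
    struct SharedState {
        explicit SharedState(int ranks): barrier(ranks) {}
        Barrier           barrier;
        bool              should_stop = false;    // written by rank 0 only
        std::atomic<bool> queue_closed{false};    // stands in for RequestQueue::close()
    };

    // cf. LlamaV2<T>::internalThreadEntry.
    void rankLoop(int rank, SharedState& state)
    {
        for (;;) {
            if (rank == 0 && state.queue_closed.load()) {
                state.should_stop = true;          // rank 0 decides the job is done
            }
            state.barrier.wait();                  // other ranks read the flag only after this
            if (state.should_stop) {
                std::cout << "rank " << rank << " exiting\n";
                return;                            // graceful exit instead of FT_CHECK(0)
            }
            // ... every rank would process the current batch here ...
            state.barrier.wait();                  // keeps ranks in lockstep between iterations
        }
    }

    int main()
    {
        constexpr int            ranks = 4;
        SharedState              state(ranks);
        std::vector<std::thread> threads;
        for (int r = 0; r < ranks; ++r) {
            threads.emplace_back([r, &state] { rankLoop(r, state); });
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
        state.queue_closed = true;                 // what ~LlamaV2 triggers via request_queue.close()
        for (auto& t : threads) {
            t.join();
        }
    }

The second barrier keeps the ranks in lockstep, so rank 0 never updates the flag while another rank could still be reading it.
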
3 changes: 3 additions & 0 deletions src/turbomind/models/llama/LlamaV2.h
@@ -46,6 +46,9 @@ class LlamaV2 {
std::vector<std::shared_ptr<Request>> stop_requests;
RequestQueue request_queue;
std::shared_ptr<Barrier> barrier;

// rank 0 sets flag to true if there are no more tasks in the request_queue
bool should_stop = false;
};

~LlamaV2();
15 changes: 14 additions & 1 deletion src/turbomind/models/llama/Request.h
@@ -44,6 +44,11 @@ class RequestQueue {
futures.reserve(requests.size());
{
std::lock_guard<std::mutex> lock(mutex_);

if (closed_) {
throw std::runtime_error("Queue is closed");
}

for (auto& r : requests) {
futures.push_back(r->signal.get_future());
if (r->stop_flag) {
@@ -65,7 +70,7 @@
{
std::unique_lock<std::mutex> lock(mutex_);
if (blocking) {
- cv_.wait(lock, [this] { return !(stop_queue_.empty() && infer_queue_.empty()); });
+ cv_.wait(lock, [this] { return !(stop_queue_.empty() && infer_queue_.empty() && closed_ == false); });
}

stop_requests.clear();
@@ -81,11 +86,19 @@
}
}

void close()
{
std::lock_guard<std::mutex> lock(mutex_);
closed_ = true;
cv_.notify_all();
}

private:
std::queue<std::shared_ptr<Request>> stop_queue_;
std::queue<std::shared_ptr<Request>> infer_queue_;
std::mutex mutex_;
std::condition_variable cv_;
bool closed_ = false;
};

} // namespace turbomind
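
A simplified model of the closable queue above may help: enqueue() rejects work once the queue is closed, dequeue() also wakes up when close() is called, and a blocked consumer can drain whatever is left and exit once it sees an empty result. ClosableQueue here is an illustrative stand-in, not the real RequestQueue (which carries std::shared_ptr<Request> and keeps separate stop/infer queues):

    #include <condition_variable>
    #include <iostream>
    #include <mutex>
    #include <queue>
    #include <stdexcept>
    #include <thread>
    #include <vector>

    // Simplified model of a closable request queue.
    class ClosableQueue {
    public:
        void enqueue(int request)
        {
            std::lock_guard<std::mutex> lock(mutex_);
            if (closed_) {
                throw std::runtime_error("Queue is closed");
            }
            queue_.push(request);
            cv_.notify_one();
        }

        // Blocks until there is work or the queue has been closed; is_empty
        // comes back true only when the queue is closed and fully drained.
        void dequeue(std::vector<int>& out, bool& is_empty)
        {
            std::unique_lock<std::mutex> lock(mutex_);
            cv_.wait(lock, [this] { return !queue_.empty() || closed_; });
            out.clear();
            while (!queue_.empty()) {
                out.push_back(queue_.front());
                queue_.pop();
            }
            is_empty = out.empty();
        }

        void close()
        {
            std::lock_guard<std::mutex> lock(mutex_);
            closed_ = true;
            cv_.notify_all();   // wake any consumer blocked in dequeue()
        }

    private:
        std::queue<int>         queue_;
        std::mutex              mutex_;
        std::condition_variable cv_;
        bool                    closed_ = false;
    };

    int main()
    {
        ClosableQueue q;
        std::thread consumer([&q] {
            std::vector<int> batch;
            bool             is_empty = false;
            for (;;) {
                q.dequeue(batch, is_empty);
                if (is_empty) {
                    return;                    // closed and drained -> stop the thread
                }
                for (int r : batch) {
                    std::cout << "handled request " << r << "\n";
                }
            }
        });
        q.enqueue(1);
        q.enqueue(2);
        q.close();        // the consumer finishes the remaining work and exits
        consumer.join();
    }
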
31 changes: 20 additions & 11 deletions src/turbomind/utils/allocator.h
@@ -125,9 +125,15 @@ class Allocator;
template<>
class Allocator<AllocatorType::CUDA>: public IAllocator {
private:
- const int device_id_;
- cudaStream_t stream_ = 0; // initialize as default stream
- std::unordered_map<void*, size_t>* pointer_mapping_;
+ enum class MemoryType
+ {
+     HOST,
+     DEVICE
+ };
+
+ const int device_id_;
+ cudaStream_t stream_ = 0; // initialize as default stream
+ std::unordered_map<void*, std::pair<size_t, MemoryType>>* pointer_mapping_;

bool isExist(void* address) const
{
@@ -136,10 +142,10 @@ class Allocator<AllocatorType::CUDA>: public IAllocator {
ReallocType isReMalloc(void* address, size_t size) const
{
FT_CHECK(isExist(address));
- if (pointer_mapping_->at(address) < size) {
+ if (pointer_mapping_->at(address).first < size) {
      return ReallocType::INCREASE;
  }
- else if (pointer_mapping_->at(address) == size) {
+ else if (pointer_mapping_->at(address).first == size) {
return ReallocType::REUSE;
}
else {
@@ -151,7 +157,7 @@
Allocator(int device_id): device_id_(device_id)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
- pointer_mapping_ = new std::unordered_map<void*, size_t>();
+ pointer_mapping_ = new std::unordered_map<void*, std::pair<size_t, MemoryType>>();
#if defined(CUDA_MEMORY_POOL_DISABLED)
TM_LOG_WARNING(
"Async cudaMalloc/Free is not supported before CUDA 11.2. Using Sync cudaMalloc/Free."
@@ -188,7 +194,9 @@ class Allocator<AllocatorType::CUDA>: public IAllocator {
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
while (!pointer_mapping_->empty()) {
- free((void**)(&pointer_mapping_->begin()->first));
+ auto ptr = pointer_mapping_->begin()->first;
+ auto size_and_type = pointer_mapping_->begin()->second;
+ free(&ptr, size_and_type.second == MemoryType::HOST);
}
delete pointer_mapping_;
}

Review comment from the PR author on the deleted line:
  1. It is UB to change the key of an unordered_map.
  2. The memory type information (host or device) was missed.
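
To illustrate the review comment above: the old loop passed the address of the map's const key straight into free(), which then wrote nullptr through it (undefined behaviour) and also had no way of knowing whether the block was host or device memory. A standard-C++ sketch of the corrected drain pattern, with std::malloc/std::free standing in for the CUDA allocation calls and freeEntry/drain as hypothetical names:

    #include <cstddef>
    #include <cstdlib>
    #include <iostream>
    #include <unordered_map>
    #include <utility>

    enum class MemoryType { HOST, DEVICE };
    using Mapping = std::unordered_map<void*, std::pair<std::size_t, MemoryType>>;

    // Stand-in for Allocator::free(): releases the block, erases the bookkeeping
    // entry and nulls the caller's pointer. std::free replaces the CUDA calls;
    // is_host would pick cudaFreeHost() vs cudaFree() in the real allocator.
    void freeEntry(Mapping& mapping, void** ptr, bool is_host)
    {
        (void)is_host;
        if (*ptr == nullptr) {
            return;
        }
        std::free(*ptr);
        mapping.erase(*ptr);
        *ptr = nullptr;   // writes through the pointer it was handed
    }

    void drain(Mapping& mapping)
    {
        while (!mapping.empty()) {
            // Copy the key into a local first. Passing (void**)&begin()->first
            // would let freeEntry() write through a pointer to the map's const
            // key (undefined behaviour), and the old code also lost the
            // host/device information stored in the value.
            void*      ptr     = mapping.begin()->first;
            const bool is_host = mapping.begin()->second.second == MemoryType::HOST;
            freeEntry(mapping, &ptr, is_host);
        }
    }

    int main()
    {
        Mapping mapping;
        void* a = std::malloc(64);
        void* b = std::malloc(128);
        mapping.insert({a, {64, MemoryType::HOST}});
        mapping.insert({b, {128, MemoryType::DEVICE}});
        drain(mapping);
        std::cout << "entries left: " << mapping.size() << "\n";   // prints 0
    }
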
@@ -229,18 +237,19 @@ class Allocator<AllocatorType::CUDA>: public IAllocator {
check_cuda_error(getSetDevice(o_device));
TM_LOG_DEBUG("malloc buffer %p with size %ld", ptr, size);

- pointer_mapping_->insert({getAddress(ptr), size});
+ pointer_mapping_->insert({getAddress(ptr), {size, is_host ? MemoryType::HOST : MemoryType::DEVICE}});

return ptr;
}

- void free(void** ptr, bool is_host = false) const
+ void free(void** ptr, bool _ = false) const
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
void* address = getAddress(*ptr);
if (*ptr != nullptr) {
int o_device = 0;
if (pointer_mapping_->count(address)) {
const auto is_host = pointer_mapping_->at(address).second == MemoryType::HOST;
TM_LOG_DEBUG("Free buffer %p", address);
check_cuda_error(getSetDevice(device_id_, &o_device));
if (is_host) {
@@ -361,7 +370,7 @@ class Allocator<AllocatorType::TF>: public IAllocator {
{
while (!pointer_mapping_->empty()) {
void* ptr = pointer_mapping_->begin()->second.flat<uint8>().data();
- free((void**)(&ptr));
+ free(&ptr);
}
pointer_mapping_->clear();
delete pointer_mapping_;
@@ -454,7 +463,7 @@ class Allocator<AllocatorType::TH>: public IAllocator {
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
while (!pointer_mapping_->empty()) {
void* ptr = pointer_mapping_->begin()->second.data_ptr();
- free((void**)(&ptr));
+ free(&ptr);
}
pointer_mapping_->clear();
delete pointer_mapping_;