From b5d9a39b7b914a5487e274d8792bcde49753d3d2 Mon Sep 17 00:00:00 2001 From: Annanya Date: Tue, 18 Mar 2025 13:45:59 -0400 Subject: [PATCH 1/2] Refactored random.h to have PhiloxRandomGenerator --- cpp/support/random.h | 47 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/cpp/support/random.h b/cpp/support/random.h index 3c9b65c687..57cb6e424d 100644 --- a/cpp/support/random.h +++ b/cpp/support/random.h @@ -8,27 +8,62 @@ #define MLC_LLM_SUPPORT_RANDOM_H_ #include +#include namespace mlc { namespace llm { -// Random number generator +// Base class for random number generators. class RandomGenerator { private: - std::mt19937 gen; - std::uniform_real_distribution<> dis; + int seed_; public: - RandomGenerator(int seed = std::random_device{}()) : gen(seed), dis(0.0, 1.0) {} + RandomGenerator(int seed = std::random_device{}()) : seed_(seed) {} static RandomGenerator& GetInstance(int seed = std::random_device{}()) { static RandomGenerator instance(seed); return instance; } - double GetRandomNumber() { return dis(gen); } + // Returns a random number in [0, 1). + virtual double GetRandomNumber() { + throw std::runtime_error("GetRandomNumber() not implemented"); + } + + // Returns a Philox offset based on the increment. + virtual uint64_t GetPhiloxOffset(uint64_t increment) { + throw std::runtime_error("GetPhiloxOffset() not implemented"); + } + + // Retrieves the seed. + int GetSeed() const { return seed_; } +}; + +class UniformRandomGenerator : public RandomGenerator { + private: + std::mt19937 gen; + std::uniform_real_distribution<> dis; + + public: + UniformRandomGenerator(int seed = std::random_device{}()) + : RandomGenerator(seed), gen(seed), dis(0.0, 1.0) {} + + double GetRandomNumber() override { return dis(gen); } +}; + +// Primarily for state tracking +class PhiloxRandomGenerator : public RandomGenerator { + private: + uint64_t offset_; - void SetSeed(int seed) { gen.seed(seed); } + public: + PhiloxRandomGenerator(int seed = std::random_device{}()) : RandomGenerator(seed), offset_(0) {} + + uint64_t GetPhiloxOffset(uint64_t increment) override { + offset_ += increment; + return offset_; + } }; } // namespace llm From be82edcd498d1eb3fe41c7e3fb0d280a05182ce3 Mon Sep 17 00:00:00 2001 From: Annanya Date: Sun, 6 Apr 2025 19:58:07 -0400 Subject: [PATCH 2/2] Changed dependencies --- cpp/serve/engine_actions/batch_verify.cc | 2 +- cpp/serve/engine_actions/eagle_batch_verify.cc | 2 +- cpp/serve/request_state.cc | 2 +- cpp/support/random.h | 15 ++++++++++----- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index 8a67c08c8f..b8d7011b4b 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -39,7 +39,7 @@ class BatchVerifyActionObj : public EngineActionObj { draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)), - rng_(RandomGenerator::GetInstance()) {} + rng_(UniformRandomGenerator::GetInstance()) {} Array Step(EngineState estate) final { // - Only run spec decode when there are two models (llm+ssm) and >=1 running requests. diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index 819b6791f3..d5bb79d351 100644 --- a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -39,7 +39,7 @@ class EagleBatchVerifyActionObj : public EngineActionObj { draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)), - rng_(RandomGenerator::GetInstance()) {} + rng_(UniformRandomGenerator::GetInstance()) {} Array Step(EngineState estate) final { // - Only run spec decode when there are two models (llm+ssm) and >=1 running requests. diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index 17e02ee85b..48b0c07643 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -156,7 +156,7 @@ RequestStateEntry::RequestStateEntry( mstates.push_back(RequestModelState(request, i, internal_id, inputs, compiled_grammar)); } n->status = RequestStateStatus::kPending; - n->rng = RandomGenerator(rng_seed); + n->rng = UniformRandomGenerator(rng_seed); n->stop_str_handler = StopStrHandler(!request->generation_cfg->debug_config.ignore_eos ? request->generation_cfg->stop_strs : Array(), diff --git a/cpp/support/random.h b/cpp/support/random.h index 57cb6e424d..05b3925148 100644 --- a/cpp/support/random.h +++ b/cpp/support/random.h @@ -21,11 +21,6 @@ class RandomGenerator { public: RandomGenerator(int seed = std::random_device{}()) : seed_(seed) {} - static RandomGenerator& GetInstance(int seed = std::random_device{}()) { - static RandomGenerator instance(seed); - return instance; - } - // Returns a random number in [0, 1). virtual double GetRandomNumber() { throw std::runtime_error("GetRandomNumber() not implemented"); @@ -49,6 +44,11 @@ class UniformRandomGenerator : public RandomGenerator { UniformRandomGenerator(int seed = std::random_device{}()) : RandomGenerator(seed), gen(seed), dis(0.0, 1.0) {} + static UniformRandomGenerator& GetInstance(int seed = std::random_device{}()) { + static UniformRandomGenerator instance(seed); + return instance; + } + double GetRandomNumber() override { return dis(gen); } }; @@ -60,6 +60,11 @@ class PhiloxRandomGenerator : public RandomGenerator { public: PhiloxRandomGenerator(int seed = std::random_device{}()) : RandomGenerator(seed), offset_(0) {} + static PhiloxRandomGenerator& GetInstance(int seed = std::random_device{}()) { + static PhiloxRandomGenerator instance(seed); + return instance; + } + uint64_t GetPhiloxOffset(uint64_t increment) override { offset_ += increment; return offset_;