diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index 8a67c08c8f..b8d7011b4b 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -39,7 +39,7 @@ class BatchVerifyActionObj : public EngineActionObj { draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)), - rng_(RandomGenerator::GetInstance()) {} + rng_(UniformRandomGenerator::GetInstance()) {} Array Step(EngineState estate) final { // - Only run spec decode when there are two models (llm+ssm) and >=1 running requests. diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index 819b6791f3..d5bb79d351 100644 --- a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -39,7 +39,7 @@ class EagleBatchVerifyActionObj : public EngineActionObj { draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)), - rng_(RandomGenerator::GetInstance()) {} + rng_(UniformRandomGenerator::GetInstance()) {} Array Step(EngineState estate) final { // - Only run spec decode when there are two models (llm+ssm) and >=1 running requests. diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index 17e02ee85b..48b0c07643 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -156,7 +156,7 @@ RequestStateEntry::RequestStateEntry( mstates.push_back(RequestModelState(request, i, internal_id, inputs, compiled_grammar)); } n->status = RequestStateStatus::kPending; - n->rng = RandomGenerator(rng_seed); + n->rng = UniformRandomGenerator(rng_seed); n->stop_str_handler = StopStrHandler(!request->generation_cfg->debug_config.ignore_eos ? 
// NOTE(review): this span was a whitespace-collapsed patch hunk for cpp/support/random.h
// (`@@ -8,27 +8,67 @@`). What follows is the post-patch content of that hunk, reformatted as
// plain C++ with review fixes applied.
#define MLC_LLM_SUPPORT_RANDOM_H_

#include <cstdint>
#include <random>
#include <stdexcept>

namespace mlc {
namespace llm {

// Base class for random number generators.
//
// Stores the seed it was constructed with and exposes two virtual operations. The base
// implementations of both operations throw std::runtime_error; each derived class overrides
// only the operation it supports.
//
// NOTE(review): assigning or copying a derived generator into an object whose static type is
// RandomGenerator slices it, after which both virtual operations throw. Holders that store a
// generator by value should be declared with the concrete type -- TODO confirm at call sites.
class RandomGenerator {
 private:
  int seed_;

 public:
  // Constructs a generator that remembers `seed`. When no seed is given, one is drawn from
  // std::random_device; the cast makes the unsigned -> int conversion explicit.
  RandomGenerator(int seed = static_cast<int>(std::random_device{}())) : seed_(seed) {}

  // Virtual so that a derived generator can be destroyed safely through a base pointer.
  virtual ~RandomGenerator() = default;

  // Declaring the destructor would otherwise suppress the implicit move operations, so all
  // copy/move operations are defaulted explicitly to keep value semantics unchanged.
  RandomGenerator(const RandomGenerator&) = default;
  RandomGenerator(RandomGenerator&&) = default;
  RandomGenerator& operator=(const RandomGenerator&) = default;
  RandomGenerator& operator=(RandomGenerator&&) = default;

  // Returns a random number in [0, 1).
  // The base implementation always throws std::runtime_error.
  virtual double GetRandomNumber() {
    throw std::runtime_error("GetRandomNumber() not implemented");
  }

  // Returns a Philox offset based on the increment.
  // The base implementation always throws std::runtime_error.
  virtual uint64_t GetPhiloxOffset(uint64_t /*increment*/) {
    throw std::runtime_error("GetPhiloxOffset() not implemented");
  }

  // Retrieves the seed passed at construction.
  int GetSeed() const { return seed_; }
};

// Generator producing doubles from a uniform real distribution over [0, 1), driven by a
// std::mt19937 engine. GetPhiloxOffset() is not overridden and therefore throws.
class UniformRandomGenerator : public RandomGenerator {
 private:
  std::mt19937 gen_;
  std::uniform_real_distribution<> dist_;

 public:
  // Seeds both the stored seed and the engine with `seed`. The int is converted to the
  // engine's unsigned result_type explicitly; the value is the same as the implicit conversion.
  UniformRandomGenerator(int seed = static_cast<int>(std::random_device{}()))
      : RandomGenerator(seed),
        gen_(static_cast<std::mt19937::result_type>(seed)),
        dist_(0.0, 1.0) {}

  // Returns the function-local static instance. `seed` is used only by the call that
  // constructs the instance; later calls return the same object and ignore their argument.
  static UniformRandomGenerator& GetInstance(
      int seed = static_cast<int>(std::random_device{}())) {
    static UniformRandomGenerator instance(seed);
    return instance;
  }

  // Draws the next value from the distribution, advancing the engine state.
  double GetRandomNumber() override { return dist_(gen_); }
};

// Primarily for state tracking: keeps a running Philox offset alongside the seed.
// GetRandomNumber() is not overridden and therefore throws.
class PhiloxRandomGenerator : public RandomGenerator {
 private:
  uint64_t offset_;

 public:
  // Starts with an offset of zero.
  PhiloxRandomGenerator(int seed = static_cast<int>(std::random_device{}()))
      : RandomGenerator(seed), offset_(0) {}

  // Returns the function-local static instance. `seed` is used only by the call that
  // constructs the instance; later calls return the same object and ignore their argument.
  static PhiloxRandomGenerator& GetInstance(
      int seed = static_cast<int>(std::random_device{}())) {
    static PhiloxRandomGenerator instance(seed);
    return instance;
  }

  // Advances the running offset by `increment` and returns the updated (post-increment) value.
  // NOTE(review): the offset in effect before the increment is never returned -- confirm that
  // consumers expect the post-increment value.
  uint64_t GetPhiloxOffset(uint64_t increment) override {
    offset_ += increment;
    return offset_;
  }
};

}  // namespace llm
}  // namespace mlc