diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 48a3e17b2..2db2cbe08 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1319,9 +1319,6 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model graph_cache_memory = 0 if self.compile: - if self.world_size > 1: - raise ValueError("Cannot enable `--compile` when sharding across multiple GPUs") - # Estimate the memory overhead from CUDA graphs so we can subtract it from the kv cache. # Needs to be estimated here rather than fully initialized as the graph cache relies on the # cache manager being set.