Draft

46 commits
7ba8356
Added fallback for when the model can't be loaded from Hugging Face
finbarrtimbers Sep 4, 2025
e7e90bb
Added new script that runs multiple models
finbarrtimbers Aug 26, 2025
7229b8d
Updated script to pass hf token in.
finbarrtimbers Aug 27, 2025
c108e8a
Updated script.
finbarrtimbers Aug 27, 2025
f99f808
Ran benchmarks.
finbarrtimbers Aug 27, 2025
6ee6d94
Commit changes.
finbarrtimbers Aug 27, 2025
0c0a8b7
Removed flashinfer
finbarrtimbers Aug 27, 2025
c48a6da
Removed flash infer
finbarrtimbers Aug 27, 2025
938def2
Fixed typo.
finbarrtimbers Aug 27, 2025
615d25a
Changed image.
finbarrtimbers Aug 28, 2025
2148cc8
Fixed script.
finbarrtimbers Aug 28, 2025
ce3e3ab
Removed uv sync
finbarrtimbers Aug 28, 2025
981f501
Fix for olmo29
finbarrtimbers Aug 28, 2025
f71e4f9
Changed batch size
finbarrtimbers Aug 30, 2025
64870f6
Fix bug
finbarrtimbers Aug 31, 2025
e380a7c
Return batch size
finbarrtimbers Aug 31, 2025
e05419d
Apply linter formatting after merge
finbarrtimbers Sep 4, 2025
fba1614
Changes to clean up benchmark.
finbarrtimbers Sep 4, 2025
0336bcf
Changed image.
finbarrtimbers Sep 4, 2025
ede0fae
trying a change
finbarrtimbers Sep 4, 2025
61ce665
Changed beaker image
finbarrtimbers Sep 4, 2025
7aec1f0
Update code
finbarrtimbers Sep 4, 2025
c86902d
Updated lock file.
finbarrtimbers Sep 4, 2025
bf49492
Updated pyproject.toml.
finbarrtimbers Sep 4, 2025
47804fc
Now, use the right branch, with our own image.
finbarrtimbers Sep 5, 2025
ef67041
Specified tokenizer
finbarrtimbers Sep 5, 2025
ccfc64d
Add weka
finbarrtimbers Sep 5, 2025
0416067
Updated script
finbarrtimbers Sep 5, 2025
58bcddf
Removed gantry script. We should use mason instead.
finbarrtimbers Sep 5, 2025
5a151c5
Added script to launch remotely using mason.
finbarrtimbers Sep 5, 2025
45b9358
Updated script to use newly built image.
finbarrtimbers Sep 5, 2025
5ad3d4c
Added a small tweak.
finbarrtimbers Sep 5, 2025
fa4f7d0
Updated dockerignore
finbarrtimbers Sep 5, 2025
1070abb
Updated script.
finbarrtimbers Sep 6, 2025
b21ca36
Fixed config
finbarrtimbers Sep 7, 2025
565a09b
Updated steps
finbarrtimbers Sep 7, 2025
41fe817
Modified lock file
finbarrtimbers Sep 7, 2025
822fe42
Now, use HF config.
finbarrtimbers Sep 8, 2025
128e5d9
Updated code
finbarrtimbers Sep 8, 2025
0993917
Changed deps according to Tyler's instructions.
finbarrtimbers Sep 8, 2025
95753cb
Updated pyproject.toml.
finbarrtimbers Sep 8, 2025
dc4c806
Removed flash infer.
finbarrtimbers Sep 9, 2025
6b3d94c
benchmark runs
finbarrtimbers Sep 11, 2025
56d9e3f
Updated benchmark; now runs.
finbarrtimbers Sep 15, 2025
795f219
Undid changes to pyproject.toml.
finbarrtimbers Sep 15, 2025
a2ac0aa
Undid changes to mason.py
finbarrtimbers Sep 15, 2025
3 changes: 2 additions & 1 deletion .dockerignore
@@ -14,7 +14,8 @@ dist/
build/
*.egg
local_dataset_cache/

benchmark_cache/
output/

# Virtual environments
.venv/
157 changes: 115 additions & 42 deletions open_instruct/benchmark_generators.py
@@ -27,8 +27,20 @@
import vllm
from ray.util import queue as ray_queue

from open_instruct import (
    actor_manager,
    dataset_transformation,
    grpo_fast,
    logger_utils,
    model_utils,
    utils,
    vllm_utils3,
)
from open_instruct.queue_types import PromptRequest

# For FLOPS, we assume bf16 and ignore sparsity.
@@ -534,16 +546,38 @@ def memory_bytes(
return total


def load_model_dims(model_name: str) -> ModelDims:
cfg = transformers.AutoConfig.from_pretrained(model_name, trust_remote_code=True)
return ModelDims(
num_layers=cfg.num_hidden_layers,
hidden_size=cfg.hidden_size,
intermediate_size=cfg.intermediate_size,
vocab_size=cfg.vocab_size,
num_attn_heads=cfg.num_attention_heads,
num_kv_heads=getattr(cfg, "num_key_value_heads", None),
)
DEFAULT_SENTINEL = object()


def maybe_get_attribute(cfg: transformers.AutoConfig, attr_names: list[str], default=DEFAULT_SENTINEL) -> Any:
"""Get the first matching attribute from cfg."""
for name in attr_names:
if hasattr(cfg, name):
return getattr(cfg, name)
if default is not DEFAULT_SENTINEL:
return default
raise ValueError(f"None of the attributes {attr_names} found in config.")


def load_model_dims(model_name: str) -> Optional[ModelDims]:
try:
cfg = transformers.AutoConfig.from_pretrained(model_name, trust_remote_code=True)
logger.info(f"HF config is: {cfg}.")
model_dims = ModelDims(
num_layers=maybe_get_attribute(cfg, ["num_hidden_layers", "n_layers"]),
hidden_size=maybe_get_attribute(cfg, ["hidden_size", "dim"]),
intermediate_size=maybe_get_attribute(cfg, ["intermediate_size"], default=None),
vocab_size=cfg.vocab_size,
num_attn_heads=cfg.num_attention_heads,
num_kv_heads=getattr(cfg, "num_key_value_heads", None),
)
if model_dims.intermediate_size is None:
model_dims.intermediate_size = 4 * model_dims.hidden_size
return model_dims
except Exception as e:
logger.warning(f"Could not load model config from Hugging Face for '{model_name}': {e}")
logger.warning("MFU and MBU calculations will not be available")
return None


def get_device_name(device_name: str) -> str:
@@ -592,7 +626,7 @@ def setup_dataset(args: grpo_fast.Args, tokenizer_config: dataset_transformation

def setup_vllm_engines(
args: grpo_fast.Args, model_config: model_utils.ModelConfig, max_model_len: int = 20480
) -> tuple[list[ray.actor.ActorHandle], ray_queue.Queue, ray_queue.Queue]:
) -> tuple[list[ray.actor.ActorHandle], ray_queue.Queue, ray_queue.Queue, ray.actor.ActorHandle]:
"""Set up vLLM engines and queues."""
logger.info("Setting up vLLM engines...")

@@ -608,8 +642,8 @@ def setup_vllm_engines(
param_prompt_Q = ray_queue.Queue(maxsize=10)
inference_results_Q = ray_queue.Queue(maxsize=10)

queues_to_monitor = {"Param Prompt Queue": param_prompt_Q, "Inference Results Queue": inference_results_Q}
actor_manager = ray.remote(ActorManager).remote(queues_to_monitor, args)
queues_to_monitor = {"param_prompt_Q": param_prompt_Q, "inference_results_Q": inference_results_Q}
actor_manager_remote = ray.remote(actor_manager.ActorManager).remote(queues_to_monitor, args)

vllm_engines = vllm_utils3.create_vllm_engines(
num_engines=args.vllm_num_engines,
@@ -628,12 +662,12 @@ def setup_vllm_engines(
max_tool_calls=[0],
prompt_queue=param_prompt_Q,
results_queue=inference_results_Q,
actor_manager=actor_manager,
actor_manager=actor_manager_remote,
)

logger.info("vLLM engines ready")

return vllm_engines, param_prompt_Q, inference_results_Q, actor_manager
return vllm_engines, param_prompt_Q, inference_results_Q, actor_manager_remote


def generate_thread(vllm_engines: list[ray.actor.ActorHandle], stop_event: threading.Event) -> None:
@@ -782,32 +816,42 @@ def run_benchmark(
prompt_lengths = [len(prompt) for prompt in prompts]
response_lengths = [len(response) for response in result.responses]

# Calculate total FLOPs for all prompts and responses in the batch
# No need to expand prompt_lengths - the flops method now handles samples_per_prompt
model_flops = model_dims.flops(
prompt_lengths, response_lengths, samples_per_prompt=args.num_samples_per_prompt_rollout
)

# MFU = (FLOPs / time) / peak_FLOPS * 100
model_flops_per_second = model_flops / batch_generation_time if batch_generation_time > 0 else 0
result_dict["mfu"] = 100 * model_flops_per_second / device_flops

# Calculate total memory bytes for all prompts and responses in the batch
model_memory_bytes = model_dims.memory_bytes(
prompt_lengths, response_lengths, samples_per_prompt=args.num_samples_per_prompt_rollout
)

# MBU = (Memory bytes / time) / peak_bandwidth * 100
model_bytes_per_second = model_memory_bytes / batch_generation_time if batch_generation_time > 0 else 0
result_dict["mbu"] = 100 * model_bytes_per_second / device_memory_bandwidth
# Calculate MFU and MBU if model dimensions are available
if model_dims is not None:
# Calculate total FLOPs for all prompts and responses in the batch
# No need to expand prompt_lengths - the flops method now handles samples_per_prompt
model_flops = model_dims.flops(
prompt_lengths, response_lengths, samples_per_prompt=args.num_samples_per_prompt_rollout
)

# MFU = (FLOPs / time) / peak_FLOPS * 100
model_flops_per_second = model_flops / batch_generation_time if batch_generation_time > 0 else 0
result_dict["mfu"] = 100 * model_flops_per_second / device_flops

# Calculate total memory bytes for all prompts and responses in the batch
model_memory_bytes = model_dims.memory_bytes(
prompt_lengths, response_lengths, samples_per_prompt=args.num_samples_per_prompt_rollout
)

# MBU = (Memory bytes / time) / peak_bandwidth * 100
model_bytes_per_second = model_memory_bytes / batch_generation_time if batch_generation_time > 0 else 0
result_dict["mbu"] = 100 * model_bytes_per_second / device_memory_bandwidth
else:
# Model dimensions not available - set to n/a
result_dict["mfu"] = "n/a"
result_dict["mbu"] = "n/a"

save_completion_lengths([result_dict], timestamp, batch_idx)
results.append(result_dict)
# Format MFU and MBU values for display
mfu_display = f"{result_dict['mfu']:.2f}%" if result_dict["mfu"] != "n/a" else result_dict["mfu"]
mbu_display = f"{result_dict['mbu']:.2f}%" if result_dict["mbu"] != "n/a" else result_dict["mbu"]

logger.info(
f"Batch {batch_idx}/{num_batches - 1}: "
f"{result_dict['tokens_per_second']:.2f} new tokens/sec, "
f"MFU: {result_dict['mfu']:.2f}%, "
f"MBU: {result_dict['mbu']:.2f}%, "
f"MFU: {mfu_display}, "
f"MBU: {mbu_display}, "
f"generation time: {batch_generation_time:.2f}s, "
f"total new tokens: {new_tokens}"
)
@@ -836,12 +880,23 @@ def aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]:
"response_lengths": [],
"prompt_lengths": [],
}

# Track if any MFU or MBU values are n/a
mfu_has_na = False
mbu_has_na = False

for result in results:
for key, value in result.items():
if key == "mfu":
aggregated_results["total_mfu"] += value
if value == "n/a":
mfu_has_na = True
else:
aggregated_results["total_mfu"] += value
elif key == "mbu":
aggregated_results["total_mbu"] += value
if value == "n/a":
mbu_has_na = True
else:
aggregated_results["total_mbu"] += value
elif key == "tokens_per_second":
aggregated_results["total_tokens_per_second"] += value
elif key == "generation_time":
@@ -860,8 +915,18 @@ def aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]:
if aggregated_results["total_generation_time"] > 0
else 0
)
aggregated_results["avg_mfu"] = aggregated_results["total_mfu"] / num_results
aggregated_results["avg_mbu"] = aggregated_results["total_mbu"] / num_results

# Set averages to n/a if any individual result was n/a
if mfu_has_na:
aggregated_results["avg_mfu"] = "n/a"
else:
aggregated_results["avg_mfu"] = aggregated_results["total_mfu"] / num_results

if mbu_has_na:
aggregated_results["avg_mbu"] = "n/a"
else:
aggregated_results["avg_mbu"] = aggregated_results["total_mbu"] / num_results

aggregated_results["avg_generation_time"] = aggregated_results["total_generation_time"] / num_results
return aggregated_results

@@ -890,8 +955,13 @@ def print_summary(
print("-" * 60)
print(f"Average results over {len(results)} main benchmark batches:")
print(f"Average tokens/second: {agg_results['avg_tokens_per_second']:.2f}")
print(f"Average MFU: {agg_results['avg_mfu']:.2f}%")
print(f"Average MBU: {agg_results['avg_mbu']:.2f}%")

# Format MFU and MBU for display
avg_mfu_display = f"{agg_results['avg_mfu']:.2f}%" if agg_results["avg_mfu"] != "n/a" else agg_results["avg_mfu"]
avg_mbu_display = f"{agg_results['avg_mbu']:.2f}%" if agg_results["avg_mbu"] != "n/a" else agg_results["avg_mbu"]

print(f"Average MFU: {avg_mfu_display}")
print(f"Average MBU: {avg_mbu_display}")
print(f"Average generation time per batch: {agg_results['avg_generation_time']:.2f}s")
print(f"Average new tokens per sample: {avg_new_tokens_per_sample:.2f} tokens")

@@ -968,7 +1038,10 @@ def main() -> None:
free_all_gpu_memory()

dataset = setup_dataset(args, tokenizer_config)
vllm_engines, param_prompt_Q, inference_results_Q, actor_manager = setup_vllm_engines(args, model_config)
max_model_len = args.max_prompt_token_length + args.response_length
vllm_engines, param_prompt_Q, inference_results_Q, actor_manager = setup_vllm_engines(
args, model_config, max_model_len=max_model_len
)

# Create the timestamp here so we use it for both filenames.
timestamp = int(time.time())
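As a minimal usage sketch of the new fallback (assumed caller code, not taken from this PR): when the Hugging Face config cannot be fetched, load_model_dims returns None and MFU is reported as "n/a" instead of crashing the benchmark. The model id and the device constants below are purely illustrative.

```python
# Assumed usage sketch of load_model_dims(); the model id and the device
# constants are illustrative, not taken from the PR.
from open_instruct.benchmark_generators import load_model_dims

DEVICE_FLOPS = 989e12      # assumed peak bf16 FLOP/s of the accelerator
GENERATION_TIME_S = 12.5   # assumed wall-clock time for one batch

model_dims = load_model_dims("allenai/OLMo-2-1124-7B")  # any HF model id

if model_dims is None:
    # Config fetch failed; utilisation metrics are reported as "n/a".
    mfu = "n/a"
else:
    # MFU = (achieved FLOP/s) / (peak FLOP/s) * 100, as in the diff above.
    total_flops = model_dims.flops([512], [256], samples_per_prompt=1)
    mfu = 100 * (total_flops / GENERATION_TIME_S) / DEVICE_FLOPS

mfu_display = mfu if mfu == "n/a" else f"{mfu:.2f}%"
print(f"MFU: {mfu_display}")
```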
4 changes: 3 additions & 1 deletion open_instruct/dataset_transformation.py
@@ -677,7 +677,9 @@ def get_tokenizer_tulu_v2_1(tc: "TokenizerConfig"):


def get_tokenizer_tulu_v2_2(tc: "TokenizerConfig"):
config = AutoConfig.from_pretrained(tc.tokenizer_name_or_path, revision=tc.tokenizer_revision)
config = AutoConfig.from_pretrained(
tc.tokenizer_name_or_path, revision=tc.tokenizer_revision, trust_remote_code=tc.trust_remote_code
)
# @vwxyzjn: "olmo" handles both `olmo2` and `olmoe`.
if "olmo" in config.model_type:
if "olmo" in tc.chat_template_name:
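For context, a small standalone sketch (not taken from this PR) of why trust_remote_code now has to reach AutoConfig.from_pretrained: for repos that ship custom modelling code, transformers typically refuses to build the config unless the flag is set. The repo name below is hypothetical.

```python
# Standalone sketch; "some-org/custom-model" is a hypothetical repo that
# ships custom code, so loading its config requires trust_remote_code=True.
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "some-org/custom-model",
    revision="main",
    trust_remote_code=True,  # mirrors the flag threaded through TokenizerConfig
)
print(config.model_type)
```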
26 changes: 25 additions & 1 deletion open_instruct/vllm_utils3.py
@@ -418,6 +418,7 @@ def __init__(
engine_args.disable_cascade_attn = True

self.llm_engine = vllm.LLMEngine.from_engine_args(engine_args)
self.logger.info("initialized llmengine")

self.prompt_queue = prompt_queue
self.results_queue = results_queue
@@ -429,6 +430,9 @@ def __init__(
self._should_stop_value = False
self._should_stop_timeout_s = 5

# Logging interval for process_from_queue
self.log_interval = 1000

def _should_stop(self) -> bool:
if (time.perf_counter() - self._last_should_stop_update) > self._should_stop_timeout_s:
should_stop_ref = self.actor_manager.should_stop.remote()
@@ -479,8 +483,28 @@ def process_from_queue(self, timeout: float = 60.0):

tracking = _init_tool_tracking()
outputs = []
iteration = 0
process_start_time = time.perf_counter()

while True:
outputs.extend(self._poll_tool_futures(tracking, self.llm_engine.tokenizer))
iteration += 1

# Periodic logging
if iteration % self.log_interval == 0:
elapsed_time = time.perf_counter() - process_start_time
num_unfinished = self.llm_engine.get_num_unfinished_requests()
pending_tools = len(tracking["pending_tool_futures"]) if tracking else 0
self.logger.info(
f"[LLMRayActor] Status update - Iteration: {iteration}, "
f"Unfinished requests: {num_unfinished}, "
f"Pending tool futures: {pending_tools}, "
f"Outputs collected: {len(outputs)}, "
f"Elapsed time: {elapsed_time:.2f}s"
)

# Poll tool futures first (matching ToolUseLLM order)
if tracking and tracking.get("pending_tool_futures"):
outputs.extend(self._poll_tool_futures(tracking, self.llm_engine.tokenizer))

# Process engine steps - ONLY if there are unfinished requests (matching ToolUseLLM)
if self.llm_engine.has_unfinished_requests():
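The periodic status logging added to process_from_queue follows a generic interval-logging pattern; here is a self-contained sketch of that pattern (every name below is made up, not the actor's real internals).

```python
# Generic interval-based progress logging for a polling loop; all names
# here are illustrative and independent of vllm_utils3.py.
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("poll_loop")

LOG_INTERVAL = 1000  # log every N iterations, like self.log_interval above


def poll_until_done(has_work, do_step) -> None:
    iteration = 0
    start = time.perf_counter()
    while has_work():
        do_step()
        iteration += 1
        if iteration % LOG_INTERVAL == 0:
            elapsed = time.perf_counter() - start
            logger.info("iteration=%d elapsed=%.2fs", iteration, elapsed)


# Example usage with dummy work items.
items = list(range(2500))
poll_until_done(lambda: bool(items), lambda: items.pop())
```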
51 changes: 0 additions & 51 deletions scripts/gantry_run_benchmark.sh

This file was deleted.
