6 changes: 3 additions & 3 deletions tests/adapters.py
@@ -41,7 +41,7 @@ def run_compute_group_normalized_rewards(
group_size: int,
advantage_eps: float,
normalize_by_std: bool,
-) -> tuple[torch.Tensor, dict[str, float]]:
+) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]:
"""
Compute rewards for each group of rollout responses,
normalized by the group size.
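
The widened return type matches what group normalization naturally produces: the advantages, the raw rewards they were derived from, and logging metadata. A minimal sketch of that computation, under assumed illustrative names (group_normalize, reward_mean, reward_std are not from this diff):

import torch

def group_normalize(raw_rewards: torch.Tensor, group_size: int,
                    advantage_eps: float, normalize_by_std: bool):
    # Reshape the flat rewards into (num_groups, group_size) and center
    # each response's reward against its own group mean.
    grouped = raw_rewards.view(-1, group_size)
    centered = grouped - grouped.mean(dim=-1, keepdim=True)
    if normalize_by_std:
        # advantage_eps guards against division by a near-zero std.
        centered = centered / (grouped.std(dim=-1, keepdim=True) + advantage_eps)
    advantages = centered.view(-1)
    metadata = {"reward_mean": raw_rewards.mean().item(),
                "reward_std": raw_rewards.std().item()}
    return advantages, raw_rewards, metadata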
@@ -90,7 +90,7 @@ def run_get_response_log_probs(
input_ids: torch.Tensor,
labels: torch.Tensor,
return_token_entropy: bool,
-) -> torch.Tensor:
+) -> dict[str, torch.Tensor]:
"""Get the conditional log-probs of the response given the prompt,
and optionally the entropy of the next token predictions.

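Returning a dict rather than a bare tensor fits the docstring: log-probs are always produced, entropy only when requested. A minimal sketch under stated assumptions (a HuggingFace-style model whose forward returns .logits, and hypothetical keys "log_probs" / "token_entropy"):

import torch

def get_response_log_probs(model, input_ids, labels, return_token_entropy):
    logits = model(input_ids).logits  # (batch, seq_len, vocab_size), assumed API
    log_probs_all = torch.log_softmax(logits, dim=-1)
    # Per-token log-prob of each label under the model's next-token distribution.
    log_probs = log_probs_all.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    out = {"log_probs": log_probs}
    if return_token_entropy:
        # Entropy of the full next-token distribution at each position.
        out["token_entropy"] = -(log_probs_all.exp() * log_probs_all).sum(dim=-1)
    return out
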
@@ -199,7 +199,7 @@ def run_sft_microbatch_train_step(
policy_log_probs: torch.Tensor,
response_mask: torch.Tensor,
gradient_accumulation_steps: int,
-    normalize_constant: int | None = 1.0,
+    normalize_constant: float | None = 1.0,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""Compute the policy gradient loss and backprop its gradients for a microbatch.
"""
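
Typing normalize_constant as float is consistent with its default of 1.0 and with its role of dividing a real-valued loss. A minimal sketch of such a microbatch step, assuming the loss is a masked sum of negative log-probs (names illustrative, not the repo's implementation):

import torch

def sft_microbatch_train_step(policy_log_probs: torch.Tensor,
                              response_mask: torch.Tensor,
                              gradient_accumulation_steps: int,
                              normalize_constant: float = 1.0):
    # Negative log-likelihood over response tokens only, divided by the
    # (float) normalization constant, e.g. 1.0 or a token count.
    nll = -(policy_log_probs * response_mask).sum() / normalize_constant
    # Scale so gradients accumulated over N microbatches average correctly.
    loss = nll / gradient_accumulation_steps
    loss.backward()
    return loss, {"nll": nll.detach()}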