diff --git a/tests/adapters.py b/tests/adapters.py
index 22c6e76..b8c15f0 100644
--- a/tests/adapters.py
+++ b/tests/adapters.py
@@ -41,7 +41,7 @@ def run_compute_group_normalized_rewards(
     group_size: int,
     advantage_eps: float,
     normalize_by_std: bool,
-) -> tuple[torch.Tensor, dict[str, float]]:
+) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]:
     """
     Compute rewards for each group of rollout responses,
     normalized by the group size.
@@ -90,7 +90,7 @@ def run_get_response_log_probs(
     input_ids: torch.Tensor,
     labels: torch.Tensor,
     return_token_entropy: bool,
-) -> torch.Tensor:
+) -> dict[str, torch.Tensor]:
     """Get the conditional log-probs of the response given the prompt,
     and optionally the entropy of the next token predictions.
@@ -199,7 +199,7 @@ def run_sft_microbatch_train_step(
     policy_log_probs: torch.Tensor,
     response_mask: torch.Tensor,
     gradient_accumulation_steps: int,
-    normalize_constant: int | None = 1.0,
+    normalize_constant: float | None = 1.0,
 ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
     """Compute the policy gradient loss and backprop its gradients for a microbatch.
     """
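To make the first annotation change concrete, below is a minimal, hypothetical sketch (not the repository's run_compute_group_normalized_rewards) of a group-normalized reward computation whose return value matches the corrected tuple[torch.Tensor, torch.Tensor, dict[str, float]] annotation: per-group normalized advantages, the raw rewards, and a metadata dict. The function name, the exact normalization formula, and the metadata keys are assumptions inferred from the parameter names, not taken from the source.

# Hypothetical sketch, assuming the group-normalized reward math suggested by the
# parameter names (group_size, advantage_eps, normalize_by_std). This is NOT the
# repository's implementation; it only illustrates the corrected return shape.
import torch


def group_normalized_rewards_sketch(
    raw_rewards: torch.Tensor,  # shape (n_rollouts,), with n_rollouts % group_size == 0
    group_size: int,
    advantage_eps: float = 1e-6,
    normalize_by_std: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]:
    # Group rollouts so each row holds the rewards for one prompt's group.
    grouped = raw_rewards.view(-1, group_size)
    # Center each reward on its group mean; optionally scale by the group std.
    advantages = grouped - grouped.mean(dim=-1, keepdim=True)
    if normalize_by_std:
        advantages = advantages / (grouped.std(dim=-1, keepdim=True) + advantage_eps)
    metadata = {
        "reward_mean": raw_rewards.mean().item(),
        "reward_std": raw_rewards.std().item(),
    }
    # Matches the corrected annotation: (advantages, raw rewards, metadata).
    return advantages.flatten(), raw_rewards, metadata

The other two changes follow the same pattern of aligning annotations with what the functions appear to return: run_get_response_log_probs presumably returns its log-probs and (when return_token_entropy is set) the token entropies under separate keys, hence dict[str, torch.Tensor], and normalize_constant is typed float to match its 1.0 default.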