Skip to content

Commit

Permalink
fix ut
Browse files Browse the repository at this point in the history
  • Loading branch information
BeingGod committed Aug 29, 2024
1 parent 352c675 commit e652a43
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1437,11 +1437,14 @@ def _generate_random_numbers(n, total_sum):
if isinstance(block, TorcGroupedLinearWithPadding):
out = block(inp_hidden_states, m_splits)
else:
padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8(
inp_hidden_states, m_splits
)
padded_inp_hidden_states = block(padded_inp_hidden_states, padding_m_splits)
out = _unpad_tensor_for_fp8(padded_inp_hidden_states, m_splits, padding_m_splits)
if fp8:
padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8(
inp_hidden_states, m_splits
)
padded_inp_hidden_states = block(padded_inp_hidden_states, padding_m_splits)
out = _unpad_tensor_for_fp8(padded_inp_hidden_states, m_splits, padding_m_splits)
else:
out = block(inp_hidden_states, m_splits)

loss = out.sum()
loss.backward()
Expand Down

0 comments on commit e652a43

Please sign in to comment.