diff --git a/axonn/tests/test_intra_layer_fc.py b/axonn/tests/test_intra_layer_fc.py index 72c8e3a..50d6069 100644 --- a/axonn/tests/test_intra_layer_fc.py +++ b/axonn/tests/test_intra_layer_fc.py @@ -76,9 +76,6 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, expert_mode, bias): assert torch.allclose(Y_sequential, Y_parallel), "FW Pass - output does not match" - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() - @pytest.mark.parametrize("B, H", [(32, 64), (16, 128), (2, 256)]) @pytest.mark.parametrize( @@ -201,8 +198,6 @@ def test_bw_pass( bias_grad_parallel, layer_sequential.bias.grad ), "BW Pass - gradients of bias do not match" - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() if __name__ == "__main__":