@@ -67,10 +67,10 @@ def test_index_mul_float(self):
6767 loss = (out_ .float ()** 2 ).sum () / out_ .numel () + (force_ .float ()** 2 ).sum ()
6868 loss .backward ()
6969
70- self . assertTrue ( torch .allclose (self .input1_float , self .input1_float_ , atol = 1e-3 , rtol = 1e-3 , equal_nan = True ) )
71- self . assertTrue ( torch .allclose (self .input2_float , self .input2_float_ , atol = 1e-3 , rtol = 1e-3 , equal_nan = True ) )
72- self . assertTrue ( torch .allclose (self .input1_float .grad , self .input1_float_ .grad , atol = 1e-3 , rtol = 1e-3 , equal_nan = True ) )
73- self . assertTrue ( torch .allclose (self .input2_float .grad , self .input2_float_ .grad , atol = 1e-3 , rtol = 1e-3 , equal_nan = True ) )
70+ torch .testing . assert_close (self .input1_float , self .input1_float_ , atol = 1e-3 , rtol = 1e-3 , equal_nan = True )
71+ torch .testing . assert_close (self .input2_float , self .input2_float_ , atol = 1e-3 , rtol = 1e-3 , equal_nan = True )
72+ torch .testing . assert_close (self .input1_float .grad , self .input1_float_ .grad , atol = 1e-3 , rtol = 1e-3 , equal_nan = True )
73+ torch .testing . assert_close (self .input2_float .grad , self .input2_float_ .grad , atol = 1e-3 , rtol = 1e-3 , equal_nan = True )
7474
7575 def test_index_mul_half (self ):
7676 out = index_mul_2d (self .input1_half , self .input2_half , self .index1 )
@@ -95,10 +95,10 @@ def test_index_mul_half(self):
9595 loss = (out_ .float ()** 2 ).sum () / out_ .numel () + (force_ .float ()** 2 ).sum ()
9696 loss .backward ()
9797
98- self . assertTrue ( torch .allclose (self .input1_half , self .input1_half_ , atol = 1e-3 , rtol = 1e-3 , equal_nan = True ) )
99- self . assertTrue ( torch .allclose (self .input2_half , self .input2_half_ , atol = 1e-3 , rtol = 1e-3 , equal_nan = True ) )
100- self . assertTrue ( torch .allclose (self .input1_half .grad , self .input1_half_ .grad , atol = 1e-3 , rtol = 1e-3 , equal_nan = True ) )
101- self . assertTrue ( torch .allclose (self .input2_half .grad , self .input2_half_ .grad , atol = 1e-3 , rtol = 1e-3 , equal_nan = True ) )
98+ torch .testing . assert_close (self .input1_half , self .input1_half_ , atol = 1e-3 , rtol = 1e-3 , equal_nan = True )
99+ torch .testing . assert_close (self .input2_half , self .input2_half_ , atol = 1e-3 , rtol = 1e-3 , equal_nan = True )
100+ torch .testing . assert_close (self .input1_half .grad , self .input1_half_ .grad , atol = 1e-3 , rtol = 1e-3 , equal_nan = True )
101+ torch .testing . assert_close (self .input2_half .grad , self .input2_half_ .grad , atol = 1e-3 , rtol = 1e-3 , equal_nan = True )
102102
103103if __name__ == '__main__' :
104104 unittest .main ()
0 commit comments