@@ -6121,5 +6121,75 @@ def forward(self, x):
61216121 np .testing .assert_allclose (pytorch_output2 .numpy (), tvm_output2_np , rtol = 1e-4 , atol = 1e-5 )
61226122
61236123
6124+ def test_advanced_indexing_with_randn ():
6125+ """Test model with randn and advanced indexing write returning a tuple."""
6126+ N = 5
6127+
6128+ class AdvancedIndexingModel (nn .Module ):
6129+ def __init__ (self ):
6130+ super ().__init__ ()
6131+ self .elu = nn .ELU ()
6132+
6133+ def forward (self , x ):
6134+ L = torch .zeros (N , N , dtype = x .dtype , device = x .device )
6135+ idx = torch .arange (N , device = x .device )
6136+ v = torch .randn (N , device = x .device )
6137+ v = self .elu (v ) + 1.0 + 1e-8
6138+ L [idx , idx ] = v
6139+ y = x + 1
6140+ return y , L
6141+
6142+ torch .manual_seed (0 )
6143+ example_input = torch .randn (2 , N )
6144+ model = AdvancedIndexingModel ().eval ()
6145+
6146+ exported_program = export (model , (example_input ,))
6147+
6148+ mod = from_exported_program (exported_program )
6149+
6150+ @I .ir_module
6151+ class Expected :
6152+ @R .function
6153+ def main (
6154+ x : R .Tensor ((2 , 5 ), dtype = "float32" )
6155+ ) -> R .Tuple (R .Tensor ((2 , 5 ), dtype = "float32" ), R .Tensor ((5 , 5 ), dtype = "float32" )):
6156+ with R .dataflow ():
6157+ lv0 = R .zeros ((5 , 5 ), dtype = "float32" )
6158+
6159+ # Use zeros instead of random normal distribution
6160+ lv1 = R .zeros ((5 ,), dtype = "float32" )
6161+
6162+ lv2 = R .nn .elu (lv1 )
6163+ lv3 = R .add (lv2 , R .const (1.0 , "float32" ))
6164+ v = R .add (lv3 , R .const (1e-8 , "float32" ))
6165+
6166+ idx = R .arange (
6167+ R .const (0 , "int64" ), R .const (5 , "int64" ), R .const (1 , "int64" ), dtype = "int64"
6168+ )
6169+
6170+ L = R .tensor_update (lv0 , (idx , idx ), v )
6171+ y = R .add (x , R .const (1 , "float32" ))
6172+
6173+ gv = R .tuple (y , L )
6174+ R .output (gv )
6175+ return gv
6176+
6177+ tvm .ir .assert_structural_equal (mod , Expected )
6178+
6179+ target = "llvm"
6180+ dev = tvm .cpu ()
6181+
6182+ exe = relax .build (mod , target = target )
6183+ vm = relax .VirtualMachine (exe , dev )
6184+ tvm_res = vm ["main" ](tvm .nd .array (example_input .numpy ()))
6185+
6186+ torch_res = model (example_input )
6187+
6188+ np .testing .assert_allclose (torch_res [0 ].numpy (), tvm_res [0 ].numpy (), rtol = 1e-7 , atol = 1e-7 )
6189+
6190+ assert tvm_res [1 ].shape == (N , N )
6191+ assert tvm_res [1 ].dtype == "float32"
6192+
6193+
61246194if __name__ == "__main__" :
61256195 tvm .testing .main ()
0 commit comments