Commit ed0fb06

[#18362][relax.frontend.torch] Add temporary solution for pytorch op 'randn'

1 parent: e28b510

2 files changed: +87 -0
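A minimal usage sketch of what this commit enables (assuming a TVM build that includes this patch): a model calling torch.randn can now be exported and translated, with the randn node lowered to a zeros tensor as described in the TODO in the diff below.

import torch
from torch import nn
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program


class RandnModel(nn.Module):
    def forward(self, x):
        # Under this patch, the exported randn node is translated to
        # relax.op.zeros((2, 3), "float32") rather than a random op.
        noise = torch.randn(2, 3)
        return x + noise


exported_program = export(RandnModel(), (torch.ones(2, 3),))
mod = from_exported_program(exported_program)
mod.show()  # the printed Relax IR contains R.zeros in place of randn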

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 17 additions & 0 deletions
@@ -64,6 +64,22 @@ def _reciprocal(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x))
 
+    def _randn(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+
+        size = args[0] if isinstance(args[0], (list, tuple)) else args
+
+        dtype = node.kwargs.get("dtype", "float32")
+        if isinstance(dtype, torch.dtype):
+            dtype = self._convert_data_type(dtype)
+
+        shape = relax.ShapeExpr(size)
+
+        # TODO: This is a temporary solution that returns zeros instead of random values
+        # since random initialization is mainly used during training, not inference.
+        # This should be updated once Relax adds proper random number generation support.
+        return self.block_builder.emit(relax.op.zeros(shape, dtype))
+
     ########## Neural Network ##########
 
     def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:

@@ -835,6 +851,7 @@ def create_convert_map(
             "pad.default": self._pad,
             "pixel_shuffle.default": self._pixel_shuffle,
             "prelu.default": self._prelu,
+            "randn.default": self._randn,
             "reciprocal.default": self._reciprocal,
             "relu.default": self._unary_op(relax.op.nn.relu),
             "relu_.default": self._unary_op(relax.op.nn.relu),

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 70 additions & 0 deletions
@@ -6121,5 +6121,75 @@ def forward(self, x):
     np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
 
 
+def test_advanced_indexing_with_randn():
+    """Test model with randn and advanced indexing write returning a tuple."""
+    N = 5
+
+    class AdvancedIndexingModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.elu = nn.ELU()
+
+        def forward(self, x):
+            L = torch.zeros(N, N, dtype=x.dtype, device=x.device)
+            idx = torch.arange(N, device=x.device)
+            v = torch.randn(N, device=x.device)
+            v = self.elu(v) + 1.0 + 1e-8
+            L[idx, idx] = v
+            y = x + 1
+            return y, L
+
+    torch.manual_seed(0)
+    example_input = torch.randn(2, N)
+    model = AdvancedIndexingModel().eval()
+
+    exported_program = export(model, (example_input,))
+
+    mod = from_exported_program(exported_program)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 5), dtype="float32")
+        ) -> R.Tuple(R.Tensor((2, 5), dtype="float32"), R.Tensor((5, 5), dtype="float32")):
+            with R.dataflow():
+                lv0 = R.zeros((5, 5), dtype="float32")
+
+                # Use zeros instead of random normal distribution
+                lv1 = R.zeros((5,), dtype="float32")
+
+                lv2 = R.nn.elu(lv1)
+                lv3 = R.add(lv2, R.const(1.0, "float32"))
+                v = R.add(lv3, R.const(1e-8, "float32"))
+
+                idx = R.arange(
+                    R.const(0, "int64"), R.const(5, "int64"), R.const(1, "int64"), dtype="int64"
+                )
+
+                L = R.tensor_update(lv0, (idx, idx), v)
+                y = R.add(x, R.const(1, "float32"))
+
+                gv = R.tuple(y, L)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+    target = "llvm"
+    dev = tvm.cpu()
+
+    exe = relax.build(mod, target=target)
+    vm = relax.VirtualMachine(exe, dev)
+    tvm_res = vm["main"](tvm.nd.array(example_input.numpy()))
+
+    torch_res = model(example_input)
+
+    np.testing.assert_allclose(torch_res[0].numpy(), tvm_res[0].numpy(), rtol=1e-7, atol=1e-7)
+
+    assert tvm_res[1].shape == (N, N)
+    assert tvm_res[1].dtype == "float32"
+
+
 if __name__ == "__main__":
     tvm.testing.main()
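A note on the checks in this test: only the deterministic output y is compared numerically against eager PyTorch, while L is checked for shape and dtype alone, because the randn values feeding L are lowered to zeros and cannot match PyTorch's random samples.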
