forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TestBatched.test_if_else_with_scalar.expect
50 lines (50 loc) · 1.98 KB
/
TestBatched.test_if_else_with_scalar.expect
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
graph(%a.1_data : Dynamic
%a.1_mask : Dynamic
%a.1_dims : Dynamic
%b_data : Dynamic
%b_mask : Dynamic
%b_dims : Dynamic) {
%6 : int = prim::Constant[value=1]()
%7 : float = prim::Constant[value=0.1]()
%8 : Float() = prim::NumToTensor(%7)
%other : float = prim::TensorToNum(%8)
%10 : Dynamic = aten::gt(%a.1_data, %other)
%11 : bool = prim::TensorToBool(%10)
%12 : Long() = prim::NumToTensor(%6)
%alpha.1 : float = prim::TensorToNum(%12)
%data.1 : Dynamic = aten::add(%a.1_data, %b_data, %alpha.1)
%mask.1 : Dynamic = aten::mul(%a.1_mask, %b_mask)
%dims.1 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%17 : Long() = prim::NumToTensor(%6)
%alpha : float = prim::TensorToNum(%17)
%data.4 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha)
%mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
%dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%22 : bool = prim::Constant[value=1]()
%23 : int = prim::Constant[value=1]()
%24 : Dynamic = aten::type_as(%a.1_mask, %10)
%cond_mask.1 : Dynamic = aten::mul(%10, %24)
%26 : int = aten::dim(%cond_mask.1)
%27 : bool = aten::eq(%26, %23)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%27)
block0() {
%31 : int = aten::dim(%data.1)
%32 : int = aten::sub(%31, %23)
%data.3 : Dynamic = prim::Loop(%32, %22, %cond_mask.1)
block0(%_ : int, %35 : Dynamic) {
%36 : int = aten::dim(%35)
%data.2 : Dynamic = aten::unsqueeze(%35, %36)
-> (%22, %data.2)
}
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
%cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask.1)
-> (%cond_data.1, %cond_mask.2, %data.3)
}
block1() {
-> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
}
%res_data : Dynamic = aten::where(%cond_data, %data.1, %data.4)
%res_mask : Dynamic = aten::where(%cond_mask, %mask.1, %mask)
%res_dims : Dynamic = aten::__or__(%dims.1, %dims)
return (%res_data, %res_mask, %res_dims);
}